In [None]:
import numpy as np
import matplotlib.pyplot as plt
import IPython.display as ipd
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from scipy import signal
import librosa
import time
from scipy.io import wavfile
import glob
from torch.optim.lr_scheduler import StepLR

In [None]:
class AudioMix(Dataset):
    def __init__(self, folder, n_examples, clip_len, sr=16000):
        """
        Setup data of audio mixtures

        Every ``*.mp3`` file in ``folder`` is treated as one source. Each
        example is a randomly weighted mixture of time-aligned clips cut from
        all of the sources.

        Parameters
        ----------
        folder: string
            Folder containing tracks
        n_examples: int
            Number of examples the dataset reports (its ``len``)
        clip_len: int
            Length of a clip, in samples
        sr: int
            Sample rate

        Raises
        ------
        ValueError
            If ``folder`` contains no ``.mp3`` files.
        """
        self.n_examples = n_examples
        # Slicing needs an integer sample count
        self.clip_len = int(clip_len)
        self.sr = sr
        # Sort so that source index i always refers to the same file; glob
        # order is arbitrary and would otherwise scramble the output channels.
        files = sorted(glob.glob("{}/*.mp3".format(folder)))
        if not files:
            raise ValueError("No .mp3 files found in {!r}".format(folder))
        self.x = []
        for f in files:
            xi, _ = librosa.load(f, sr=sr)
            self.x.append(xi)

    def __len__(self):
        return self.n_examples

    def __getitem__(self, idx):
        """
        Draw one random mixture; ``idx`` is ignored.

        S is "source", M is "mixture".

        Returns
        -------
        M: torch.Tensor, shape (1, clip_len), float32
            Mixture audio: the sum of the weighted source clips
        S: torch.Tensor, shape (n_sources, clip_len), float32
            Weighted source clips, one row per entry of ``self.x``

        Raises
        ------
        ValueError
            If the shortest source is shorter than ``clip_len``.
        """
        n_clips = len(self.x)
        # Random non-negative weights normalized to sum to 1
        mix = np.random.rand(n_clips)
        mix = mix/np.sum(mix)
        # Use the shortest track so the same window exists in every source
        N = min(len(xi) for xi in self.x)
        if N < self.clip_len:
            raise ValueError("Shortest track has {} samples, fewer than clip_len={}".format(N, self.clip_len))
        # Valid start offsets are 0..N-clip_len inclusive; randint's upper
        # bound is exclusive, hence the +1.
        i1 = np.random.randint(N - self.clip_len + 1)
        S = np.stack([mi*xi[i1:i1+self.clip_len] for mi, xi in zip(mix, self.x)]).astype(np.float32)
        M = S.sum(axis=0) # Mixture audio
        return torch.from_numpy(M[None, :]), torch.from_numpy(S)
        
# Build the dataset from the stems in the "Aha" folder and inspect one example
data = AudioMix("Aha", n_examples=2000, clip_len=16384)
M, S = data[0]
for label, tensor in (("S.shape", S), ("M.shape", M)):
    print(label, tensor.shape)

In [None]:
# Batch some examples and pass the mixtures through one probe convolution
# to see the shape a Conv1d layer produces (padding 7 with kernel 15 keeps length).
loader = DataLoader(data, batch_size=16, shuffle=True)
M, S = next(iter(loader))
print(S.shape)
conv = nn.Conv1d(in_channels=1, out_channels=24, kernel_size=15,
                 stride=1, padding=7, bias=False)
Mc = conv(M)
print(Mc.shape)

In [None]:
class Decimate(nn.Module):
    """
    Decimate by a factor of fac along axis 2 (the time axis of a
    (batch, channels, time) tensor) by keeping every fac-th sample,
    starting at index 0. No anti-aliasing filter is applied.
    """
    def __init__(self, fac):
        super(Decimate, self).__init__()
        # A zero or negative step is not a meaningful decimation factor
        if fac < 1:
            raise ValueError("fac must be a positive integer, got {}".format(fac))
        self.fac = fac

    def forward(self, X):
        return X[:, :, 0::self.fac]

class WaveUNet(nn.Module):
    def __init__(self, C, L=12, Fc=24, fd=15, fu=5):
        """
        U-Net on raw audio that maps a mono mixture to C estimated sources.

        Parameters
        ----------
        C: int
            Number of mixture components
        L: int
            Number of layers
        Fc: int
            Number of extra filters per layer
        fd: int
            Kernel size for downsampling
        fu: int
            Kernel size for upsampling

        Raises
        ------
        ValueError
            If L < 1, or if fd or fu is even (the length-preserving padding
            k//2 used here only preserves length for odd kernels).
        """
        super(WaveUNet, self).__init__()
        if L < 1:
            raise ValueError("L must be at least 1, got {}".format(L))
        if fd % 2 == 0 or fu % 2 == 0:
            raise ValueError("fd and fu must be odd, got fd={}, fu={}".format(fd, fu))
        self.C = C
        self.L = L

        ## Step 1: Create the convolutional down layers
        # Down layer i (1-indexed) maps Fc*(i-1) channels (1 for the raw
        # mixture) to Fc*i channels; padding fd//2 keeps the length unchanged.
        self.down = nn.ModuleList()
        in_ch = 1
        for i in range(1, L + 1):
            self.down.append(nn.Conv1d(in_ch, Fc*i, fd, padding=fd//2))
            in_ch = Fc*i
        # Bottleneck convolution at the coarsest resolution
        self.bottleneck = nn.Conv1d(Fc*L, Fc*(L+1), fd, padding=fd//2)

        ## Step 2: Create the convolutional up layers
        # self.up[j] serves level i = L - j: its input is the upsampled
        # features (Fc*(i+1) channels) concatenated with the skip connection
        # from down layer i (Fc*i channels); its output has Fc*i channels.
        self.up = nn.ModuleList()
        for i in range(L, 0, -1):
            self.up.append(nn.Conv1d(Fc*(i+1) + Fc*i, Fc*i, fu, padding=fu//2))

        ## Step 3: Create the last layer
        # 1x1 convolution over the top-level features concatenated with the
        # input mixture, giving one output channel per source.
        self.out = nn.Conv1d(Fc + 1, C, 1)

    def forward(self, M, verbose=False):
        """
        Parameters
        ----------
        M: torch.Tensor, shape (batch, 1, N)
            Mixture audio; any N >= 1 is accepted
        verbose: bool
            If True, print the feature shape after every stage

        Returns
        -------
        torch.Tensor, shape (batch, C, N)
            Estimated sources
        """
        relu = nn.LeakyReLU()
        downsample = Decimate(2)
        upsample = nn.Upsample(scale_factor=2, mode='linear')

        X = M
        skips = []
        for i, conv in enumerate(self.down, start=1):
            X = relu(conv(X))
            skips.append(X) # Full-resolution features for the up path
            X = downsample(X)
            if verbose:
                print("down", i, tuple(X.shape))
        X = relu(self.bottleneck(X))
        if verbose:
            print("bottleneck", tuple(X.shape))
        for i, (conv, skip) in zip(range(self.L, 0, -1), zip(self.up, reversed(skips))):
            # Decimating an odd length n leaves ceil(n/2) samples, so
            # upsampling can overshoot by one; crop to the skip's length.
            X = upsample(X)[:, :, :skip.shape[2]]
            X = relu(conv(torch.cat([X, skip], dim=1)))
            if verbose:
                print("up", i, tuple(X.shape))
        X = self.out(torch.cat([X, M], dim=1))
        if verbose:
            print("out", tuple(X.shape))
        return X
        
    
# Smoke test: one forward pass on the batch drawn earlier
n_sources = len(data.x)
model = WaveUNet(C=n_sources)
S_est = model(M)
print(S_est.shape)

# Training loop

In [None]:
# Try to use the GPU, falling back to the CPU when CUDA is unavailable
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_data(path, sr=16000):
    """
    Load an audio file as a mixture tensor for testing the model.

    The signal is truncated to the largest power-of-2 number of samples that
    fits, so its length halves exactly at each factor-2 decimation.

    Parameters
    ----------
    path: string
        Audio file to load
    sr: int
        Sample rate librosa resamples to

    Returns
    -------
    torch.Tensor, shape (1, 1, N), float32
        The truncated signal, moved to the global ``device``

    Raises
    ------
    ValueError
        If the file contains no samples.
    """
    # Test data; librosa's returned sample rate equals ``sr`` and is not needed
    M_test, _ = librosa.load(path, sr=sr)
    if M_test.size == 0:
        raise ValueError("{} contains no audio samples".format(path))
    # Round down to nearest power of 2 using exact integer arithmetic
    N = 1 << (int(M_test.size).bit_length() - 1)
    M_test = M_test[0:N]
    M_test = np.array(M_test[None, None, :], dtype=np.float32)
    M_test = torch.from_numpy(M_test).to(device)
    return M_test

# Held-out mixture whose separation is written out after every epoch.
# Alternative input, the original recording: get_data("TakeOnMe.mp3")
M_test = get_data("TakeOnMeMidiMix.mp3")

In [None]:
# Model
model = WaveUNet(len(data.x), L=11)
model = model.to(device)

## Optimizer; the loss (summed squared error between the estimated and
## true sources) is computed inside the training loop
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

n_epochs = 200 # Each "epoch" is a loop through the entire dataset
# and we use this to update the parameters
train_losses = []

# Multiply the learning rate by 0.9 every 20 epochs
scheduler = StepLR(optimizer, step_size=20, gamma=0.9)

# Sample rate written into the separated test tracks; must match the rate
# M_test was loaded at (get_data's default)
test_sr = 16000


def save_tracks(S_est, sr, prefix="track"):
    """
    Write each source of the first batch item to "<prefix><i>.wav" as 16-bit PCM.

    Parameters
    ----------
    S_est: np.ndarray, shape (batch, n_sources, n_samples)
        Estimated sources; only S_est[0] is written
    sr: int
        Sample rate stored in the wav header
    prefix: string
        File name prefix
    """
    for i, s in enumerate(S_est[0]):
        peak = np.max(np.abs(s))
        if np.isfinite(peak) and peak > 0:
            # Normalize to 32767, the int16 maximum: scaling by 32768 makes a
            # positive peak overflow and wrap around to -32768.
            pcm = np.array(s*32767/peak, dtype=np.int16)
        else:
            # Silent or non-finite output has no usable peak; write silence
            pcm = np.zeros(s.shape, dtype=np.int16)
        wavfile.write("{}{}.wav".format(prefix, i), sr, pcm)


# A shuffling DataLoader reshuffles every time it is iterated, so one
# instance serves all epochs
loader = DataLoader(data, batch_size=16, shuffle=True)
for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for M, S in loader: # Go through each mini batch
        # Move inputs/outputs to the compute device
        M = M.to(device)
        S = S.to(device)
        # Reset the optimizer's gradients
        optimizer.zero_grad()
        # Run the model on all inputs
        S_est = model(M)
        # Summed squared error comparing S_est to S
        loss = torch.sum((S_est - S)**2)
        # Compute the gradients of the loss function with respect
        # to all of the parameters of the model
        loss.backward()
        # Update the parameters based on the gradient and
        # the optimization scheme
        optimizer.step()
        train_loss += loss.item()

    print("Epoch {}, loss {:.3f}".format(epoch, train_loss))
    train_losses.append(train_loss)
    scheduler.step()

    # Separate the test mixture to monitor progress. no_grad avoids building
    # an autograd graph over the whole song; eval() turns off any train-only
    # layer behavior for this pass.
    model.eval()
    with torch.no_grad():
        S_test = model(M_test).cpu().numpy()
    save_tracks(S_test, test_sr)
    

In [None]:
# Summed training loss per epoch
fig, ax = plt.subplots()
ax.plot(train_losses)
ax.set(title="Training loss", xlabel="Epoch", ylabel="Summed squared error");