In [None]:
# Utils
from glob import glob

# Numbers
import numpy as np

# Visualization
import matplotlib.pyplot as plt
from IPython.display import Image, Audio, HTML
import librosa.display

# Machine learning
import torch
import torchaudio
from torchaudio.functional import resample
from sklearn.model_selection import train_test_split
from torch import nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split, TensorDataset, DataLoader

# Audio
import torchaudio
import librosa

In [None]:
class Encoder(nn.Module):
    """Fully connected encoder built from Linear -> BatchNorm1d -> ELU stages.

    Args:
        layers_size: Sequence of layer widths. Each consecutive pair
            ``(layers_size[k], layers_size[k + 1])`` gives the input and
            output width of stage ``k``. Fewer than two widths produces an
            empty (identity) stack.
    """

    def __init__(self, layers_size):
        super().__init__()
        stages = []
        # Walk the widths pairwise: (w0, w1), (w1, w2), ...
        for n_in, n_out in zip(layers_size[:-1], layers_size[1:]):
            stages += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ELU()]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through every stage in order and return the result."""
        return self.layers(x)

In [None]:
class Decoder(nn.Module):
    """Fully connected decoder that walks ``layers_size`` from last to first.

    Hidden stages are Linear -> BatchNorm1d -> ELU; the output stage is
    Linear -> BatchNorm1d with no activation after it.

    Args:
        layers_size: Sequence of layer widths in encoder order (widest input
            first). At least two widths are required; a shorter sequence
            raises ``IndexError``.
    """

    def __init__(self, layers_size):
        super().__init__()
        stages = []
        # Hidden stages: from the last index down to index 2, each mapping
        # layers_size[idx] -> layers_size[idx - 1].
        for idx in reversed(range(2, len(layers_size))):
            n_in, n_out = layers_size[idx], layers_size[idx - 1]
            stages += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.ELU()]
        # Output stage back to the original frame width; deliberately left
        # without a trailing activation.
        n_in, n_out = layers_size[1], layers_size[0]
        stages += [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out)]
        self.layers = nn.Sequential(*stages)

    def forward(self, x):
        """Pass ``x`` through every stage in order and return the result."""
        return self.layers(x)

In [None]:
class AutoEncoder(pl.LightningModule):
    """Dense autoencoder over log-power STFT frames, with audio I/O helpers.

    Every training sample is a single spectrogram frame, so ``layers_size[0]``
    must equal the number of frequency bins produced by ``fourier``
    (``win_length // 2 + 1`` with torch's default one-sided STFT).

    Args:
        layers_size: Layer widths, from the input frame width down to the
            bottleneck. Shared by the encoder and the mirrored decoder.
        sample_rate: Rate in Hz that all audio is resampled to.
        duration: Maximum number of seconds kept from each audio file.
        hop_length_ms: STFT hop in milliseconds; the window is four hops long.
        Sclip: Floor, in dB, applied to the log-power spectrogram.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, layers_size, sample_rate=22050, duration=120, hop_length_ms=20, Sclip=-60, learning_rate=0.001):
        super().__init__()
        self.layers_size = layers_size
        self.encoder = Encoder(self.layers_size)
        self.decoder = Decoder(self.layers_size)
        self.sample_rate = sample_rate
        self.duration = duration
        self.hop_length = int(sample_rate * hop_length_ms / 1000)
        self.Sclip = Sclip
        self.win_length = 4 * self.hop_length
        self.learning_rate = learning_rate
        # Filled in by load_training_audio(); X_max can also be restored from
        # a checkpoint. Declared here so "not loaded yet" is detectable.
        self.X = None
        self.y = None
        self.phases = None
        self.X_max = None

    def forward(self, x):
        """Encode then decode a batch of (normalised) spectrogram frames."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def training_step(self, batch, batch_idx):
        """Return the MSE between the reconstruction and the input frames."""
        # The dataset yields (input, target) pairs, but the target is the
        # input itself, so only the first element is needed.
        x, _ = batch
        x_hat = self(x)
        loss = nn.MSELoss()(x_hat, x)
        self.log("loss", loss, on_epoch=True, on_step=False)
        return loss

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def on_save_checkpoint(self, checkpoint):
        """Store the spectrogram scale next to the weights.

        ``predict`` needs ``X_max`` to (de)normalise frames, and it is a plain
        attribute that is not part of the ``state_dict``.
        """
        if self.X_max is not None:
            checkpoint["X_max"] = self.X_max

    def on_load_checkpoint(self, checkpoint):
        """Restore the spectrogram scale if the checkpoint carries one."""
        # Checkpoints written before X_max was saved simply lack the key.
        if checkpoint.get("X_max") is not None:
            self.X_max = checkpoint["X_max"]

    def train_model(self, epochs=120, batch_size=512):
        """Fit the autoencoder on the frames loaded by ``load_training_audio``.

        Args:
            epochs: Maximum number of training epochs.
            batch_size: Number of frames per batch.

        Raises:
            RuntimeError: If no training data has been loaded yet.
        """
        if self.X is None or self.y is None:
            raise RuntimeError("No training data loaded; call load_training_audio() first.")
        # 5% of the frames are held out by the split; the held-out part is
        # not used for validation in this method.
        X_train, _, _, _ = train_test_split(self.X, self.y, test_size=0.05, shuffle=True)
        # Extra all-zero frames (10% of the training frames). Zero is the
        # clip floor after `fourier` subtracts Sclip, i.e. silence.
        silence = np.zeros([int(X_train.shape[0] * 0.1), X_train.shape[1]])
        X_data = torch.tensor(np.vstack([silence, X_train])).float()
        dataset = TensorDataset(X_data, X_data)
        # Shuffle so the silence rows (stacked first) are mixed into real
        # batches instead of forming constant all-zero batches every epoch.
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        trainer = pl.Trainer(max_epochs=epochs)
        trainer.fit(self, dataloader)

    def load_checkpoint(self, path):
        """Load weights (and the saved scale, if any) from a Lightning checkpoint.

        Args:
            path: Path to a checkpoint file containing a ``'state_dict'`` entry.
        """
        # map_location lets a checkpoint saved on GPU load on a CPU-only
        # machine; load_state_dict copies values into the existing parameters.
        checkpoint = torch.load(path, map_location="cpu")
        self.load_state_dict(checkpoint['state_dict'])
        self.on_load_checkpoint(checkpoint)

    def predict(self, specgram):
        """Reconstruct a spectrogram with the trained network.

        Runs in evaluation mode without gradient tracking, so BatchNorm uses
        its running statistics (and a single frame is accepted). The module's
        previous train/eval mode is restored afterwards.

        Args:
            specgram: Spectrogram frames in the same units as ``fourier`` returns.

        Returns:
            The reconstruction, rescaled back to the input units.

        Raises:
            RuntimeError: If the normalisation scale ``X_max`` is not available.
        """
        if self.X_max is None:
            raise RuntimeError(
                "X_max is not set; call load_training_audio() or load a checkpoint that stores it."
            )
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                S_hat = self(specgram / self.X_max)
        finally:
            self.train(was_training)
        return S_hat * self.X_max

    def format_waveform(self, waveform, original_sr):
        """Resample, down-mix to mono and trim a waveform.

        Args:
            waveform: Audio tensor whose first dimension is averaged away
                (channels-first layout, as returned by ``torchaudio.load``).
            original_sr: Sample rate of ``waveform`` in Hz.

        Returns:
            The mono waveform at ``self.sample_rate``, at most
            ``self.duration`` seconds long.
        """
        waveform = resample(waveform, original_sr, self.sample_rate)  # Resample to custom sample rate
        waveform = torch.mean(waveform, dim=0)  # Convert to mono
        # int() keeps the slice valid when duration is given as a float.
        waveform = waveform[:int(self.sample_rate * self.duration)]  # Trim to custom duration
        return waveform

    def get_waveform(self, path):
        """Load an audio file and return it formatted by ``format_waveform``."""
        waveform, original_sr = torchaudio.load(path)
        waveform = self.format_waveform(waveform, original_sr)
        return waveform

    def fourier(self, waveform):
        """Compute the phase and the clipped log-power spectrogram.

        No window is passed to ``torch.stft``, so torch's default
        (rectangular) window is used; ``get_player`` inverts with the same
        default.

        Args:
            waveform: Mono waveform tensor.

        Returns:
            ``(phase, S)``, both transposed to frames-first layout. ``S`` is
            the power in dB, floored at ``Sclip`` and shifted so that the
            floor is 0.
        """
        F = torch.stft(waveform, n_fft=self.win_length, hop_length=self.hop_length, win_length=self.win_length, return_complex=True).T
        S = 10*torch.log10(torch.abs(F)**2)
        # Zero-power bins give -inf here; the clip maps them to the floor.
        S = S.clip(self.Sclip, None) - self.Sclip
        return torch.angle(F), S

    def get_specgram(self, path):
        """Load an audio file and return its ``(phase, spectrogram)`` pair."""
        waveform = self.get_waveform(path)
        phase, S = self.fourier(waveform)
        return phase, S

    def load_training_audio(self, path):
        """Load every ``.wav`` file in a folder as normalised training frames.

        Populates ``self.X`` (frames scaled by the global maximum),
        ``self.X_max`` (that maximum), ``self.phases`` and ``self.y`` (the
        index of the source file for every frame).

        Args:
            path: Folder containing the ``.wav`` files.

        Raises:
            FileNotFoundError: If the folder contains no ``.wav`` file.
        """
        pattern = path + "/*.wav"
        # glob returns files in arbitrary order; sort so the per-file labels
        # in self.y are the same on every run and machine.
        filenames = sorted(glob(pattern))
        if not filenames:
            raise FileNotFoundError(f"No .wav files found matching {pattern!r}")
        phases, frames, labels = [], [], []
        for i, filename in enumerate(filenames):
            phase, S = self.get_specgram(filename)
            phases.append(phase)
            frames.append(S)
            labels.append(torch.ones(S.shape[0]) * i)
        self.phases = torch.vstack(phases)
        X = torch.vstack(frames)
        self.X_max = X.max()
        self.X = X / self.X_max
        self.y = torch.hstack(labels)

    def plot_specgram(self, specgram):
        """Plot a frames-first spectrogram with time and linear-frequency axes."""
        plt.figure(figsize=(14, 4))
        # Pass the real sample rate: specshow otherwise assumes 22050 Hz and
        # mislabels both axes for any other rate.
        librosa.display.specshow(
            specgram.detach().cpu().numpy().T,
            y_axis='linear',
            x_axis='time',
            sr=self.sample_rate,
            hop_length=self.hop_length,
        )
        plt.colorbar()

    def get_player(self, phase, specgram):
        """Rebuild audio from phase and spectrogram and show an audio player.

        Args:
            phase: Phase as returned by ``fourier`` (frames first).
            specgram: Spectrogram in the units returned by ``fourier``.
        """
        # Imported here so the method does not rely on the notebook-injected
        # `display` builtin.
        from IPython.display import display

        # Undo the dB shift and the log to get magnitudes, then re-attach phase.
        S_hat = torch.sqrt(10**((specgram+self.Sclip)/10))*torch.exp(1j*phase)
        waveform_hat = torch.istft(S_hat.T, n_fft=self.win_length, hop_length=self.hop_length, win_length=self.win_length)
        display(Audio(waveform_hat.detach().cpu().numpy(), rate=self.sample_rate))

In [None]:
# Seed Python, NumPy and torch so weight initialisation and the later
# train/hold-out shuffle are repeatable under Restart & Run All.
pl.seed_everything(42)

# The first width must equal the number of STFT frequency bins per frame
# (win_length // 2 + 1 = 883 for 22050 Hz audio with a 20 ms hop); the
# remaining widths shrink down to an 8-dimensional bottleneck.
layers_size = [883, 512, 256, 128, 64, 32, 16, 8]
autoencoder = AutoEncoder(layers_size, sample_rate=22050, duration=120, hop_length_ms=20, Sclip=-60, learning_rate=0.001)
assert layers_size[0] == autoencoder.win_length // 2 + 1, "input width must match the STFT bin count"

In [None]:
# Round-trip one clip through the STFT helpers (the network is not involved):
# listen to the reconstructed audio and inspect the clipped log-power spectrogram.
og_phase, specgram = autoencoder.get_specgram(path="wavs/audio.wav")
autoencoder.get_player(phase=og_phase, specgram=specgram)
autoencoder.plot_specgram(specgram=specgram)

In [None]:
# Turn every .wav file in the folder into normalised training frames.
training_audio_path = "wavs"
autoencoder.load_training_audio(path=training_audio_path)

In [None]:
# Long-running cell: up to 1024 epochs, 512 frames per batch.
autoencoder.train_model(
    epochs=1024,
    batch_size=512,
)

In [None]:
# Open TensorBoard inline on the Lightning log directory to inspect the
# logged "loss" curve.
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/