In [30]:
#importing basic libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
%matplotlib inline

#Importing pytorch functions and modules
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(torch.cuda.get_device_name(0))
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
import torchvision
import torchvision.datasets as datasets
from torchvision import transforms

import librosa
from librosa.core import stft,istft
from sklearn.metrics import *
from math import log

#Setting random seed for reproducibility
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

GeForce GTX 1050


In [31]:
s, sr=librosa.load("train_clean_male.wav", sr=None)
S_complex=stft(s, n_fft=1024, hop_length=512)
x, sr=librosa.load("train_dirty_male.wav", sr=None)
X_complex=stft(x, n_fft=1024, hop_length=512)



print(X_complex.shape)
print(S_complex.shape)

X = np.abs(X_complex.T)
S = np.abs(S_complex.T)

X_dir = np.divide(X_complex,X.T)
X = torch.tensor(X).to(device)
S = torch.tensor(S).to(device)

print(X.shape)

(513, 2459)
(513, 2459)
torch.Size([2459, 513])


In [32]:
#Defining shapes
H = 1
D = X.shape[1]
K = 1

H_out = 1
D_out = S.shape[1]
K_out = 1
n_data = X.shape[0]

In [33]:
class Network1d(nn.Module):
    def __init__(
            self,
            K=K,
            H=H,
            D=D,
            K_out=K_out,
            H_out=H_out,
            D_out=D_out,
            dropout=0.2,
    ):
        super(Network1d, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.relu = F.relu
        
        self.K_hidden1 = 5
        self.D_hidden1 = 256
        self.stride = 2
        self.maxp1_K = 2

        self.D_last = self.get_conv1d_shape(D_in=D,
                                            K_hidden=self.K_hidden1,
                                            D_hidden=self.D_hidden1,
                                            stride=self.stride,
                                            maxp_K=self.maxp1_K)
        
        self.conv1 = nn.Conv1d(in_channels=K,
                               out_channels=self.K_hidden1,
                               kernel_size=self.D_hidden1,
                               stride=self.stride)
        self.maxp1 = nn.MaxPool1d(kernel_size=2)
        self.fc1 = nn.Linear(self.K_hidden1*self.D_last,D_out)
        
        

    def forward(self, x):
        
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxp1(x)
        x = self.dropout(x)
        
        x = x.view(-1,self.K_hidden1*self.D_last)
        x = self.fc1(x)
        
        return(x)
    
    def get_conv1d_shape(self,D_in,K_hidden,D_hidden,stride,maxp_K):
        l_out = (D_in - (D_hidden-1) -1)/(stride) + 1
        l_out = int(l_out)
        l_ans = (l_out - (maxp_K-1) -1)/(maxp_K) + 1
        l_ans = int(l_ans)
        return(l_ans)

In [34]:
def train_neural_network(model,train_dataset,val_dataset,epochs,early_stopping_rounds,batch_size,
                         learning_rate,verbose,criterion,eval_func,device):
  
  
    #Dictionary where all the important outputs will be kept
    result_dict = dict()
    #Reading number of samples in each set
    n_train = len(train_dataset)
    n_val = len(val_dataset)

      # Data loader, using custom user provided batch size
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, 
                                                 batch_size=batch_size, 
                                                 shuffle=True)

    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, 
                                                 batch_size=n_val, 
                                                 shuffle=True)



    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


    # Train the model
    epoch=1
    stop=0
    best_performance = -100000
    rounds = 0
    stop = False
    model = model.to(device)

      #Training while loop 
    if(verbose==True):
        print("Training commenced")
    while ((epoch <= epochs)and(stop==False)):
        train_loss = 0
        for batch, labels in train_loader:  
            # Move tensors to GPU/CPU

            batch = batch.to(device)
            labels = labels.to(device)

            # Forward pass
            outputs = model.forward(batch)
            loss = criterion(outputs, labels)

            # Backward and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()


        with torch.no_grad():
            val_batch,val_labels = next(iter(val_loader))
            val_batch = val_batch.to(device)
            val_labels_pred = model.forward(val_batch).cpu().numpy()

            performance = eval_func(val_labels.cpu().numpy(),val_labels_pred)


        #Check if we have an increase in performance
        if(performance > best_performance):
            rounds = 0
            best_performance = performance
            best_prediction = val_labels_pred
            best_state_dict = model.state_dict()
            result_dict["train_loader"] = train_loader
            result_dict["val_batch"] = val_batch
            result_dict["val_labels"] = val_labels
            result_dict["best_model"] = model
            result_dict["val_labels_pred"] = best_prediction
            result_dict["best_performance"] = best_performance
        else:
            rounds += 1
            if(rounds == early_stopping_rounds):
                stop = True


        #Print statement, every 5 epochs or if it is the last epoch
        if(((epoch%5==0)|(stop==True))&(verbose==True)):
            print("EPOCH:"+str(epoch))
            if(stop==True):
                print("Training to be concluded after this epoch") 
            print("Average training loss per sample  = "+str(train_loss))
            print('Performance of the network in current epoch = '+str(round(performance,4)))
            print('Best performance of the network yet  = '+str(round(best_performance,4)))


        epoch += 1
    #While loop ends

    print("BEST SCORE IS:"+str(best_performance))



    return(result_dict)

In [35]:

from sklearn.model_selection import train_test_split
X = X.reshape([n_data,1,D])
S = S.reshape([n_data,D_out])

X_ind = np.array(range(0,n_data))
S_ind = np.array(range(0,n_data))

X_train_ind,X_val_ind,S_train_ind,S_val_ind = train_test_split(X_ind,S_ind,test_size=0.25,shuffle=True,random_state=SEED)
X_train = X[X_train_ind,:,:]
print(X_train.shape)
X_val = X[X_val_ind,:,:]
print(X_val.shape)
S_train = S[S_train_ind,:]
S_val = S[S_val_ind,:]

train_dataset = torch.utils.data.TensorDataset(X_train,S_train)
val_dataset = torch.utils.data.TensorDataset(X_val,S_val)


from sklearn.metrics import mean_squared_error
def neg_mean_squared_error(y_true,y_pred):
  return((-1)*(mean_squared_error(y_true,y_pred)))

#Defining hyper parameters 
batch_size = 300
dropout = 0.2
learning_rate = 0.001
epochs=200
early_stopping_rounds = 10
criterion = nn.MSELoss()



torch.Size([1844, 1, 513])
torch.Size([615, 1, 513])


In [36]:
cnn1d = Network1d(
            K=K,
            H=H,
            D=D,
            K_out=K_out,
            H_out=H_out,
            D_out=D_out,
            dropout=0.2,
    )
result_dict = train_neural_network(cnn1d,
                                   train_dataset,
                                   val_dataset,
                                   epochs,
                                   early_stopping_rounds,
                                   batch_size,
                                   learning_rate,
                                   verbose=True,
                                   criterion=criterion,
                                   eval_func=neg_mean_squared_error,
                                   device=device)     

Training commenced
EPOCH:5
Average training loss per sample  = 0.4132792465388775
Performance of the network in current epoch = -0.0745
Best performance of the network yet  = -0.0745
EPOCH:10
Average training loss per sample  = 0.348116397857666
Performance of the network in current epoch = -0.0645
Best performance of the network yet  = -0.0645
EPOCH:15
Average training loss per sample  = 0.291540190577507
Performance of the network in current epoch = -0.0526
Best performance of the network yet  = -0.0526
EPOCH:20
Average training loss per sample  = 0.21804503723978996
Performance of the network in current epoch = -0.0422
Best performance of the network yet  = -0.0422
EPOCH:25
Average training loss per sample  = 0.18131311796605587
Performance of the network in current epoch = -0.0365
Best performance of the network yet  = -0.0363
EPOCH:30
Average training loss per sample  = 0.15034741163253784
Performance of the network in current epoch = -0.031
Best performance of the network yet  = 

In [103]:
def get_snr(s_actual,s_pred):
    num = np.dot(s_actual,s_actual)
    den = np.dot((s_actual-s_pred),(s_actual-s_pred))
    return(10*log(num/den,10))

def get_time_domain_signal(S,S_dir,length):
    S_complex = np.multiply(S,S_dir)
    return(istft(S_complex, hop_length=512, length = length))

def clean_signal(input_file,output_file,model,device):
    """
    params :
    input_file = name of input wav file to read
    output_file = name of output wav file to write
    model = pretrained model object for prediction
    device = device(pytorch) on which model is trained
    """
    x, sr=librosa.load(input_file, sr=None)
    X_complex=stft(x, n_fft=1024, hop_length=512)
    X = np.abs(X_complex.T)


    X_dir = np.divide(X_complex,X.T)
    X = torch.tensor(X).to(device)
    D = X.shape[1]
    n_data = X.shape[0]

    X = X.reshape([n_data,1,D])

    X_clean_pred = model.forward(X).detach().cpu().numpy()

    x_clean_pred = get_time_domain_signal(X_clean_pred.T,X_dir,length=len(x))
    librosa.output.write_wav(output_file,x_clean_pred, sr)
    return

In [100]:
S_pred = cnn1d.forward(X).detach().cpu().numpy()
s_pred = get_time_domain_signal(S_pred.T,X_dir,length=len(x))
print("Mean Squared Error = "+str(mean_squared_error(S.cpu().numpy(),S_pred)))
print("SNR = "+str(get_snr(s,s_pred)))

Mean Squared Error = 0.016903779
SNR = 8.616533193464864


In [104]:
clean_signal("test_x_01.wav","test_x_01_pred.wav",cnn1d,device)
clean_signal("test_x_01.wav","test_x_01_pred.wav",cnn1d,device)
clean_signal("test_x_02.wav","test_x_02_pred.wav",cnn1d,device)