In [None]:
# Mount Google Drive into the Colab runtime. Later cells read the audio clips,
# songs and model checkpoints from paths under "drive/My Drive/Colab Notebooks".
# NOTE(review): `files` is not referenced anywhere else in the visible notebook.
from google.colab import files
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%capture
import os, sys
from os.path import exists
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
!pip install torch torchaudio torchvision
!pip install livelossplot
!ln -sf /opt/bin/nvidia-smi /usr/bin/nvidia-smi
!pip install -U nvgpu

import torch
import torchaudio
import matplotlib.pyplot as plt
import random
import torch.nn as nn
import torch.nn.functional as F
from livelossplot import PlotLosses
import math
import numpy as np
import torchvision
from IPython.display import clear_output
from IPython.display import Audio
import librosa
import os
import pynvml

# Print tensors with 7 decimals and prefer the GPU whenever CUDA is available.
torch.set_printoptions(precision=7)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



In [None]:
# a function to track the available GPU memory
def memory():
  """Print the total, used and free memory of GPU 0 in MiB, as reported by NVML.

  NVML is initialised on entry and always shut down again on exit, even when
  one of the queries raises.
  """
  GPU_ID = 0
  MB_SIZE = 1024*1024

  pynvml.nvmlInit()
  try:
    handle = pynvml.nvmlDeviceGetHandleByIndex(GPU_ID)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handle)
    # the 'available' line reports the device's total memory
    print('available: ',round(meminfo.total/MB_SIZE))
    print('used:      ',round(meminfo.used/MB_SIZE))
    print('free:      ',round(meminfo.free/MB_SIZE))
  finally:
    # release the NVML session regardless of what happened above
    pynvml.nvmlShutdown()

In [None]:
# load the normalized audio clips into tensors

bongo_dir = "drive/My Drive/Colab Notebooks/Bachata/1 bar/bongo/normalized"
guira_dir = "drive/My Drive/Colab Notebooks/Bachata/1 bar/guira/normalized"
bass_dir = "drive/My Drive/Colab Notebooks/Bachata/1 bar/bass/normalized"
guitar_dir = "drive/My Drive/Colab Notebooks/Bachata/1 bar/guitar/normalized"
directories = [bongo_dir,guira_dir,bass_dir,guitar_dir]
instruments = ['bongo','guira','bass','guitar']

FULL_BAR_SAMPLES = 36000                    # only clips of exactly this length are used
HALF_BAR_SAMPLES = FULL_BAR_SAMPLES // 2    # each accepted clip is cut into two halves


def _load_half_bars(directory):
  """Return the half-bar waveforms (1-D tensors) found in `directory`.

  Every ".aiff" file is loaded and its first channel is kept. Files whose
  length is not FULL_BAR_SAMPLES are skipped; the others contribute their
  first and second half as two separate clips. Files are visited in sorted
  order because os.listdir returns them in arbitrary order.

  Raises:
    ValueError: if the directory yields no usable clip.
  """
  clips = []
  for filename in sorted(os.listdir(directory)):
    if not filename.endswith(".aiff"):
      continue
    waveform = torchaudio.load(os.path.join(directory,filename))[0][0]
    if len(waveform) == FULL_BAR_SAMPLES:          #splitting the 1 bar loops into 2 half bar loops
      clips.append(waveform[:HALF_BAR_SAMPLES])
      clips.append(waveform[HALF_BAR_SAMPLES:])
  if not clips:
    raise ValueError("no {0}-sample .aiff clips found in {1!r}".format(FULL_BAR_SAMPLES, directory))
  return clips


# one list of clips per instrument, in the order of `directories`
total_list = [_load_half_bars(directory) for directory in directories]
bongo_list, guira_list, bass_list, guitar_list = (torch.stack(clips) for clips in total_list)
instruments_wave = [bongo_list,guira_list,bass_list,guitar_list]

In [None]:
# obtaining the magnitude and phase matrices via short-time Fourier transform
def _stft(waveform):
    waveform = waveform.squeeze(0)
    spec = torch.stft(waveform,n_fft=510,hop_length=70,win_length=510,window=torch.hann_window(510))  #short time Fourier transform to produce a complex matrix
    mag = (spec[...,0] * spec[...,0] + spec[...,1] * spec[...,1])**0.5 # piecewise modulus of the complex stft matrix to obtain the amplitude/magnitude matrix
    phase = torch.atan2(spec[...,1], spec[...,0]) # calculating the piecewise angle of the complex matrix, this is the same as np.angle
    mag = mag[...,1:-1]       # resulting shape is 256x258 --> cut down the edges to have 256x256 square shaped matrices
    phase = phase[...,1:-1]

    
    return mag, phase



In [None]:
# heatmap calculation of the magnitude to obtain visible spectrograms
def heatmap(mag):
    """Turn a batch of magnitude matrices into normalised log-spectrograms.

    Args:
        mag: NumPy array whose first axis indexes the items of the batch.

    Returns:
        float32 torch tensor with a channel axis inserted at position 1. Each
        item is log10(mag + 1) divided by its own maximum. An all-zero
        (silent) item stays all-zero instead of becoming NaN through 0/0.
        The input array is not modified.
    """
    mag = np.log10(mag+1)     # bringing the magnitude to the visible spectrum (new array)
    for i in range(len(mag)):         #normalizing the values of the magnitude within the batch
        peak = np.max(mag[i])
        if peak != 0:                 # skip silent items: dividing by 0 would give NaN
            mag[i] = mag[i]/peak
    mag = mag.astype(np.float32)
    return torch.from_numpy(mag).unsqueeze(1)



In [None]:
#get_batch function used to produce the input spectrograms for the network from the audio clips
def _sample_clips(pool, k):
    """Draw `k` clips from `pool` with replacement and stack them into one tensor."""
    return torch.stack(random.choices(pool, k=k))


def _resample_duplicates(candidates, reference, pool):
    """Redraw, in place, every row of `candidates` that equals the same row of `reference`."""
    if len(pool) < 2:
        # with fewer than two clips a differing clip can never be drawn
        raise ValueError("need at least two clips to pick a distinct one")
    for row in range(len(candidates)):
        while torch.equal(candidates[row], reference[row]):
            candidates[row] = pool[random.randrange(len(pool))]
    return candidates


def get_batch(batch_size,N_=1,instruments_wave=instruments_wave,instrument='random'):
    """Build one batch of mixtures together with targets for one (or all) instruments.

    Args:
        batch_size: number of examples. Must be at least 2, because `_stft`
            squeezes a leading dimension of size 1.
        N_: unused; kept for backward compatibility.
        instruments_wave: per-instrument clip tensors, indexed in the order
            bongo, guira, bass, guitar.
        instrument: 'random', 'bongo', 'guira', 'bass', 'guitar' or 'all'.
            'random' and 'all' both pick the first instrument at random.

    Returns:
        For instrument == 'all' (testing):
            spec_mix, phase_mix, wave_mix,
            audio0, spec_target0, audio1, spec_target1,
            audio2, spec_target2, audio3, spec_target3, instrument_list
        where audioK is the gain-scaled clip of instrument K inside the mixture
        and spec_targetK is the spectrogram of a different clip of that instrument.
        Otherwise (training):
            phase_mix, phase_true, phase_target, wave_mix, spec_mix,
            spec_true, spec_target, amp_mix, instrument_name, audio0

    Raises:
        ValueError: for an unknown `instrument`, or when an instrument has
            fewer than two clips.
    """
    instruments = ['bongo','guira','bass','guitar']
    # 1st entry is the instrument index, 2nd its 'weight': the gain applied to
    # make the instrument louder or quieter in the mixture
    index_list = [[0,1],[1,3],[2,1],[3,1/3]]

    #selecting the target instrument. it could be random or specified
    if instrument in ('random', 'all'):
        index_pick = random.choice(index_list)
    elif instrument in instruments:
        index_pick = index_list[instruments.index(instrument)]
    else:
        raise ValueError("unknown instrument {0!r}; expected one of {1}".format(
            instrument, instruments + ['random', 'all']))
    index, index_scale = index_pick

    #the remaining instruments, so that the selected one won't appear twice in the mixture
    others = [pair for pair in index_list if pair is not index_pick]

    #selecting 2 distinct instances of the target instrument: one target and one true
    target_pool = instruments_wave[index]
    x_target = _sample_clips(target_pool, batch_size)
    x_true = _resample_duplicates(_sample_clips(target_pool, batch_size), x_target, target_pool)

    #choosing the rest of the instruments for the mixture at random
    other_sources = [_sample_clips(instruments_wave[idx], batch_size) for idx, _gain in others]
    audio_target = x_target * index_scale
    audio0 = x_true * index_scale
    audio1, audio2, audio3 = (src * gain for src, (_idx, gain) in zip(other_sources, others))

    #adding the audios together to create the mixture
    wave_mix = audio0 + audio1 + audio2 + audio3

    #magnitude (amplitude) and phase of the mixture, and its spectrogram via the heatmap
    amp_mix, phase_mix = _stft(wave_mix)
    spec_mix = heatmap(amp_mix.numpy())

    #this section is only used when extracting all 4 instruments from the mixture during testing
    if instrument == 'all':
        # Compare against the UNSCALED clips that went into the mixture: the
        # gain-scaled audio never equals a raw clip when the gain is not 1,
        # so checking against it would let the mixture's own clip through.
        sources = [x_true] + other_sources
        pools = [target_pool] + [instruments_wave[idx] for idx, _gain in others]
        spec_targets = []
        for source, pool in zip(sources, pools):
            #creating a target for each instrument, note that no ground truth needed at this point
            target = _resample_duplicates(_sample_clips(pool, batch_size), source, pool)
            # NOTE(review): these targets are not gain-scaled, whereas the
            # training target below is (x_target * index_scale) - confirm intent.
            amp_target_k, _phase = _stft(target)
            spec_targets.append(heatmap(amp_target_k.numpy()))

        #the order at which the targets were picked
        instrument_list = [instruments[index]] + [instruments[idx] for idx, _gain in others]
        #use this only during testing
        return (spec_mix,phase_mix,wave_mix,
                audio0,spec_targets[0],audio1,spec_targets[1],
                audio2,spec_targets[2],audio3,spec_targets[3],instrument_list)

    #spectrograms of the ground truth and of the target (training only)
    amp_true,phase_true = _stft(audio0)
    spec_true = heatmap(amp_true.numpy())

    amp_target,phase_target = _stft(audio_target)
    spec_target = heatmap(amp_target.numpy())

    #use this during training
    return phase_mix, phase_true, phase_target, wave_mix, spec_mix, spec_true, spec_target, amp_mix, instruments[index], audio0

bsize = 4
phase_mix,phase_true, phase_target, audio_mix, spec_mix, spec_true, spec_target, ori_amp_mix, instrument,true_original  = get_batch(bsize)



In [None]:
#demonstrating a random spectrogram of each instrument
specmix,_,_,_,spec_target0,_,spec_target1,_,spec_target2,_,spec_target3,inst_list=get_batch(bsize,instrument='all')
fig=plt.figure(figsize=(20,5))
# one panel for the mixture followed by one panel per instrument, left to right
panels = [('mix', specmix)] + list(zip(inst_list, (spec_target0, spec_target1, spec_target2, spec_target3)))
for slot, (label, spectrogram) in enumerate(panels, start=1):
  fig.add_subplot(1,5,slot).set_title(label,fontsize=18)
  plt.imshow(spectrogram[0][0])    # first example of the batch, single channel
plt.show()

In [None]:
#demonstrating the target, ground truth and mixture
fig = plt.figure(figsize=(3*bsize,5*2))
# grid rows from top to bottom: mixture, target, ground truth; one column per example
grid_rows = (('spec mix {0}', spec_mix), ('spec target {0}', spec_target), ('spec true {0}', spec_true))
for b in range(bsize):
  for row, (caption, specs) in enumerate(grid_rows):
    plt.subplot(3,bsize,b+1+row*bsize).set_title(caption.format(b))
    plt.imshow(specs[b,0].numpy())
plt.tight_layout()


In [None]:
# defining our U-Net architecture
class UNet(nn.Module):
    """Five-level U-Net built from conv/batch-norm/ReLU blocks.

    The encoder halves the resolution four times with max pooling; the decoder
    doubles it four times with bilinear interpolation and concatenates the
    matching encoder activation (skip connection) before each block. The skip
    activations are kept on the module as `conv1` ... `conv4` by `encode` and
    read back by `decode`.
    """

    def __init__(self, in_class=2, out_class=1, test = -1):     #two in class: mixture,target; one out class: prediction
        super().__init__()
        self.instrument_index = test

        widths = [64, 128, 256, 512, 1024]

        # encoder blocks dconv_down1..5, created shallowest first
        encoder_inputs = [in_class] + widths[:-1]
        for level, (c_in, c_out) in enumerate(zip(encoder_inputs, widths), start=1):
            setattr(self, 'dconv_down{0}'.format(level), self.conv_block(c_in, c_out))

        # decoder blocks dconv_up4..1: input is the skip plus the upsampled deeper features
        for level in (4, 3, 2, 1):
            skip_channels, deeper_channels = widths[level - 1], widths[level]
            setattr(self, 'dconv_up{0}'.format(level),
                    self.conv_block(skip_channels + deeper_channels, skip_channels))

        self.conv_last = nn.Conv2d(widths[0], out_class, 1)

    def conv_block(self, in_channels, out_channels):                #convolutional block we used in down and upsampling
        layers = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.decode(self.encode(x))

    def encode(self,x):             #downsampling
        encoder_blocks = (self.dconv_down1, self.dconv_down2, self.dconv_down3, self.dconv_down4)
        for level, block in enumerate(encoder_blocks, start=1):
            skip = block(x)
            setattr(self, 'conv{0}'.format(level), skip)    # skip-connection for this level
            x = F.max_pool2d(skip, 2)
        return self.dconv_down5(x)

    def decode(self,x):             #upsampling
        stages = (
            (self.conv4, self.dconv_up4),
            (self.conv3, self.dconv_up3),
            (self.conv2, self.dconv_up2),
            (self.conv1, self.dconv_up1),
        )
        for skip, block in stages:
            x = F.interpolate(x, scale_factor=2, mode='bilinear')
            x = torch.cat([x, skip], dim=1)     # skip-connection
            x = block(x)
        return self.conv_last(x)

# build the network on the selected device and report its size
N = UNet().to(device)

n_parameters = sum(parameter.numel() for parameter in N.parameters())
print(f'> Number of network parameters {n_parameters}')

# initialise the optimiser, the epoch counter and the live loss plot
optimiser = torch.optim.Adam(N.parameters(), lr=0.001)
epoch = 0
liveplot = PlotLosses()

In [None]:
#load previously saved trained models
def loadmodel(model,path):
  """Restore a checkpoint written by the training cell.

  The weights are loaded into `model` and the optimiser state into the
  notebook-level `optimiser`. The model is left in training mode.

  Args:
    model: network whose architecture matches the checkpoint.
    path: checkpoint file holding the keys 'model_state_dict',
      'optimizer_state_dict', 'epoch', 'loss' and 'loss_list'.

  Returns:
    (model, optimiser, epoch, loss, loss_list)

  NOTE(review): the global `optimiser` must have been built from `model`'s own
  parameters, otherwise the restored state is attached to another network.
  NOTE(review): recent PyTorch releases load with weights_only=True by
  default; checkpoints whose 'loss_list' holds NumPy arrays may need
  weights_only=False - confirm against the installed version.
  """
  # Map tensors onto the device the model lives on, so that a checkpoint
  # saved on a GPU can also be opened on a CPU-only runtime.
  first_param = next(model.parameters(), None)
  map_location = first_param.device if first_param is not None else None

  checkpoint = torch.load(path, map_location=map_location)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimiser.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  loss = checkpoint['loss']
  loss_list = checkpoint['loss_list']

  model.train()
  return model,optimiser,epoch,loss,loss_list

In [None]:
#defining path for model to be saved during training
# file name of the checkpoint; the training cell passes `path` to torch.save
model_save_name = 'test_model.pt'
path = "drive/My Drive/Colab Notebooks/{0}".format(model_save_name)
# list the files in that folder whose names end in "pt"
!ls drive/My\ Drive/Colab\ Notebooks/*pt

In [None]:
#model training
loss_list_=[]
save_bool = True    #indicate to save model whilst training at every epoch
K = 3 #number of prediction samples to show at each epoch
N_EPOCHS = 2            #training stops once `epoch` reaches this value
STEPS_PER_EPOCH = 100   #optimisation steps (batches) per epoch
TRAIN_BATCH_SIZE = 16


def _as_image_grid(batch):
  """Tile a [B x C x H x W] batch into a single H x W x C image for plt.imshow."""
  return torchvision.utils.make_grid(batch).cpu().detach().permute(1,2,0)


while epoch<N_EPOCHS:
  print('epoch',epoch)
  step_losses = []
  for i in range(STEPS_PER_EPOCH):
    phase_mix,phase_true, phase_target,_, x_mix,x_true,x_target,_,instrument,_ = get_batch(TRAIN_BATCH_SIZE)
    x_target,x_true,x_mix = x_target.to(device), x_true.to(device), x_mix.to(device)
    optimiser.zero_grad()
    cnn_in = torch.cat([x_target,x_mix], dim=1)    #constructing the input for the network of size [Bx2x256x256]

    p = N(cnn_in)             #prediction
    p = torch.clamp(p,0,1)    #clipping values between 0 and 1

    loss = ((p-x_true)**2).mean()     #mean squared error loss between predicted and ground truth spectrograms
    loss.backward()
    optimiser.step()
    # Record the loss of EVERY step, so that the value plotted below is the
    # mean over the epoch rather than the loss of its last batch only.
    step_losses.append(loss.item())
  train_loss_arr = np.array(step_losses)

  #prediction demonstration at every epoch (uses the last batch of the epoch)
  clear_output(wait=True)
  fig = plt.figure(figsize=(15,6))
  fig.add_subplot(2,2,1).set_title('mix epoch {0}'.format(epoch))
  plt.imshow(_as_image_grid(x_mix[:K]))
  fig.add_subplot(2,2,3).set_title('target {0}'.format(instrument))
  plt.imshow(_as_image_grid(x_target[:K]))
  fig.add_subplot(2,2,2).set_title('true {0}'.format(instrument))
  plt.imshow(_as_image_grid(x_true[:K]))
  fig.add_subplot(2,2,4).set_title('predicted {0}'.format(instrument))
  plt.imshow(_as_image_grid(p[:K]))

  liveplot.update({
        'loss': train_loss_arr.mean()
  })
  loss_list_.append(train_loss_arr)   #collecting the losses to save them

  liveplot.draw()     #plotting the loss
  plt.tight_layout()
  plt.show()
  epoch = epoch+1

  #saving the state of the model at every epoch
  if save_bool:
    torch.save({
              'epoch': epoch,
              'model_state_dict': N.state_dict(),
              'optimizer_state_dict': optimiser.state_dict(),
              # plain float: does not tie the checkpoint to a device or an autograd graph
              'loss': loss.item(),
              'loss_list': loss_list_,
              }, path)

In [None]:
#function to plot the average loss over every x epoch
def plotmeanloss(loss_list,epoch_range_,fname=None):
  """Plot the loss averaged over consecutive windows of `epoch_range_` epochs.

  Args:
    loss_list: sequence with one loss entry per epoch.
    epoch_range_: window length in epochs (integer >= 1). When the number of
      epochs is not a multiple of it, the last window holds the remainder.
    fname: optional figure super-title.

  Each window is drawn at the x position of its last epoch. For windows longer
  than one epoch the standard deviation within the window is shown as an
  error bar.

  Raises:
    ValueError: if `epoch_range_` is smaller than 1.
  """
  if epoch_range_ < 1:
    raise ValueError("epoch_range_ must be at least 1, got {0}".format(epoch_range_))

  max_epoch = len(loss_list)
  x=[]
  y=[]
  y_err=[]
  for start in range(0, max_epoch, epoch_range_):
    stop = min(start + epoch_range_, max_epoch)   # the final window may be shorter
    y_range = loss_list[start:stop]
    x.append(stop)
    y.append(np.mean(y_range))
    y_err.append(np.std(y_range))     #error calculation

  fig = plt.figure(figsize=(6,5))
  if fname is not None:
    fig.suptitle('{0}'.format(fname))
  plt.title('Loss over every {0} epochs'.format(epoch_range_))
  if epoch_range_ ==1:    #if x==1, then don't plot error bars
    plt.plot(x,y)
  else:                   # if x>1, do an errorbar plot with the calculated errors
    plt.errorbar(x,y,y_err,0,fmt='o-',mew=2,ms=7,ecolor='gray',capsize=5,elinewidth=2)
  plt.xlabel('epoch')
  plt.ylabel('loss')



In [None]:
#reversing the previously seen heatmap calculations to obtain the predicted magnitude matrix
def inverse_heatmap(heatmap_):
  """Undo the log10(mag + 1) step of `heatmap`.

  Args:
    heatmap_: CPU tensor without gradient tracking and with a channel axis of
      size 1 at position 1.

  Returns:
    Tensor with that channel axis removed, holding 10**value - 1. The
    per-item division by the maximum done in `heatmap` is not reverted.
  """
  log_magnitude = heatmap_.squeeze(1).numpy()
  return torch.from_numpy(np.power(10, log_magnitude) - 1)     #reversing the log_10


In [None]:
#reversing the magnitude and phase matrix calculations to produce an input for the inverse STFT
def istft_reconstruction(mag, phase):
  """Rebuild a waveform from magnitude and phase matrices.

  Uses the same STFT parameters as `_stft` (n_fft=510, hop_length=70, Hann
  window of length 510) and fixes the output at 18000 samples.

  Args:
    mag: real tensor of magnitudes, [freq x frames] or [batch x freq x frames].
    phase: real tensor of angles, broadcastable against `mag`.

  Returns:
    Real tensor of 18000 samples per item.
  """
  # real and imaginary parts from the angles, scaled by the magnitude
  real = mag*torch.cos(phase)
  imag = mag*torch.sin(phase)

  #constructing the complex matrix input for the istft
  complex_matrix = torch.complex(real, imag)

  # torch.istft is used because torchaudio.functional.istft is not available
  # in current torchaudio releases; the window follows the data's device.
  window = torch.hann_window(510, device=complex_matrix.device)
  #fixing the sample length to be 18000 preserves the audio duration
  wav = torch.istft(complex_matrix,n_fft=510,hop_length=70,win_length=510,window=window,length=18000)
  return wav

In [None]:
#choose which trained model to load
modeldict = 'drive/My Drive/Colab Notebooks/'
# os.listdir returns names in arbitrary order; sort them so that the dropdown
# lists the checkpoints in a stable order, like the song picker does
modellist = sorted(name for name in os.listdir(modeldict) if name.endswith('.pt'))
import ipywidgets as widgets
model_picker = widgets.Dropdown(options=modellist)
model_picker


In [None]:
#define the network and load the selected model into it
N = UNet().to(device)
# loadmodel restores the optimiser state into the global `optimiser`. Rebuild
# it for the fresh network first, otherwise it keeps pointing at the
# parameters of the previous `N` and further training would not update this one.
optimiser = torch.optim.Adam(N.parameters(), lr=0.001)
N,_,_,_,loss_list = loadmodel(N,'{0}{1}'.format(modeldict,model_picker.value))
print(model_picker.value)

In [None]:
# loss curve of the loaded checkpoint, one point per epoch (no averaging window)
plotmeanloss(loss_list, epoch_range_=1)

In [None]:
#testing the trained model
def test(model,myinstrument,bsize,show_audio=True):
  """Run the model on freshly drawn mixtures and show or play the results.

  Args:
    model: trained network taking a [B x 2 x 256 x 256] input (target
      spectrogram concatenated with mixture spectrogram).
    myinstrument: 'all' to extract every instrument from the same mixtures,
      otherwise any value accepted by `get_batch` ('random' or an instrument name).
    bsize: number of examples to show. `get_batch` needs at least two
      examples, so for bsize == 1 a batch of two is drawn.
    show_audio: whether to display audio players.

  Returns:
    For 'all': (prediction_list, true_list, audio_mix, instrument_list), one
    entry per instrument; for bsize == 1 they still hold the two drawn examples.
    Otherwise: (wav_prediction, wav_true, wav_mix) as NumPy arrays holding
    `bsize` examples.

  NOTE(review): the model is run in whatever mode it is currently in;
  `loadmodel` leaves it in training mode, so batch-norm uses batch statistics
  here - confirm that this is intended.
  """
  RATE = 11025

  def _grid(batch):
    # tile the batch into one H x W x C image for imshow
    return torchvision.utils.make_grid(batch).cpu().detach().permute(1,2,0)

  # a minimum batch size of 2 is required for get_batch to work
  fetch_size = max(bsize, 2)

  if myinstrument == "all":     #extracting all instruments from the test clip
    prediction_list=[]
    (my_x_mix, my_phase_mix, audio_mix,
     true0, target0, true1, target1,
     true2, target2, true3, target3, instrument_list) = get_batch(fetch_size,instrument=myinstrument)
    target_list=[target0,target1,target2,target3]
    true_list = [true0,true1,true2,true3]

    for target in range(4):   #for each instrument, take the corresponding spectrograms and feed into the network
      print('\nINSTRUMENT {0}\n'.format(instrument_list[target]))
      cnn_in = torch.cat([target_list[target],my_x_mix], dim=1).to(device)
      with torch.no_grad():     #inference only: no autograd graph is needed
        prediction = torch.clamp(model(cnn_in),0,1)  #as before, clamp predictions between 0 and 1
      amp_inverse_prediction = inverse_heatmap(prediction.cpu())     #predicted magnitude
      wav_prediction = istft_reconstruction(amp_inverse_prediction,my_phase_mix)      #predicted audio
      wav_prediction = wav_prediction.cpu().numpy()
      prediction_list.append(wav_prediction)

      if show_audio:
        for b in range(bsize):      #the mix and the predicted audio together with the corresponding ground truth audio
          print('example {0}'.format(b))
          print('MIX')
          display(Audio(audio_mix[b],rate=RATE))
          print('prediction')
          display(Audio(wav_prediction[b],rate=RATE))
          print('true')
          display(Audio(true_list[target][b],rate=RATE))
    return prediction_list,true_list,audio_mix,instrument_list

  #if a target instrument is specified, only fetch a target spectrogram of that instrument
  (my_phase_mix, my_phase_true, my_phase_target, _,
   my_x_mix, my_x_true, my_x_target, _,
   my_instrument, true_original) = get_batch(fetch_size,instrument=myinstrument)
  if fetch_size != bsize:
    # keep only the requested number of examples
    my_phase_mix, my_phase_true = my_phase_mix[:bsize], my_phase_true[:bsize]
    my_x_mix, my_x_true, my_x_target = my_x_mix[:bsize], my_x_true[:bsize], my_x_target[:bsize]

  cnn_in = torch.cat([my_x_target,my_x_mix], dim=1).to(device)    # as before, input to our cnn
  with torch.no_grad():
    prediction = torch.clamp(model(cnn_in),0,1)    #predictions clamped as before

  #plot the mixture, the target, the ground truth, and the predicted spectrograms together for easy comparison
  fig = plt.figure(figsize=(15,8))
  fig.add_subplot(3,2,1).set_title('mix')
  plt.imshow(_grid(my_x_mix))
  fig.add_subplot(3,2,3).set_title('target {0}'.format(my_instrument))
  plt.imshow(_grid(my_x_target))
  fig.add_subplot(3,2,2).set_title('true {0}'.format(my_instrument))
  plt.imshow(_grid(my_x_true))
  fig.add_subplot(3,2,4).set_title('predicted {0}'.format(my_instrument))
  plt.imshow(_grid(prediction))
  plt.tight_layout()
  plt.show()

  #producing the magnitudes and audio of the predicted, ground truth, and mixture spectrograms
  def _to_wave(spectrogram, phase):
    magnitude = inverse_heatmap(spectrogram.detach().cpu())
    return istft_reconstruction(magnitude, phase).cpu().numpy()

  wav_prediction = _to_wave(prediction, my_phase_mix)
  wav_true = _to_wave(my_x_true, my_phase_true)
  wav_mix = _to_wave(my_x_mix, my_phase_mix)

  print(my_instrument,'\n')
  #once again display the audio clips together within a batch for easy comparison
  if show_audio:
    for b in range(bsize):
      print('batch {0}\n'.format(b+1))
      print('     PREDICTED')
      display(Audio(wav_prediction[b],rate=RATE))
      print('     TRUE')
      display(Audio(wav_true[b],rate=RATE))
      print('     MIX')
      display(Audio(wav_mix[b],rate=RATE))
      print('\n\n')
  return wav_prediction,wav_true,wav_mix

In [None]:
#choosing which instrument to filter and how many samples(batch size) we want

instrument_to_extract = "all" #@param ["random","bongo", "guira", "bass", "guitar",'all']

batch_size =  5#@param {type:"integer"}

# test() returns four values in "all" mode but only three for a single
# instrument, so unpack according to the selected mode
test_results = test(N,instrument_to_extract,batch_size,show_audio=True)
if instrument_to_extract == "all":
  preds,trues,mix, inst_list = test_results
else:
  preds,trues,mix = test_results


In [None]:
# this stretching function was obtained from https://github.com/gaganbahga/time_stretch
def stretch(x, factor,length=0, nfft=2048):
    '''
    stretch an audio sequence by a factor using FFT of size nfft converting to frequency domain
    :param x: np.ndarray, audio array in PCM float32 format
    :param factor: float, stretching or shrinking factor, depending on if its > or < 1 respectively
    :param length: int, when non-zero the output is zero-padded or clipped to exactly this
                   many samples; 0 (default) keeps the natural length of the inverse STFT
    :param nfft: int, FFT size used for the forward and inverse STFT
    :return: (np.ndarray, int), the time stretched audio and the number of samples of x
    '''
    stft = librosa.core.stft(x, n_fft=nfft).transpose()  # i prefer time-major fashion, so transpose
    stft_cols = stft.shape[1]

    times = np.arange(0, stft.shape[0], factor)  # times at which new FFT to be calculated
    hop = nfft // 4                              # frame shift (stft above is called with its default hop)
    # np.complex128 is the portable name (np.complex_ was removed in NumPy 2.0)
    stft_new = np.zeros((len(times), stft_cols), dtype=np.complex128)
    phase_adv = (2 * np.pi * hop * np.arange(0, stft_cols))/ nfft
    phase = np.angle(stft[0])

    # one extra zero frame so that `left_frame + 1` is always a valid index
    stft = np.concatenate( (stft, np.zeros((1, stft_cols))), axis=0)

    for i, time in enumerate(times):
        left_frame = int(np.floor(time))
        local_frames = stft[[left_frame, left_frame + 1], :]
        right_wt = time - np.floor(time)                        # weight on right frame out of 2
        local_mag = (1 - right_wt) * np.absolute(local_frames[0, :]) + right_wt * np.absolute(local_frames[1, :])
        local_dphi = np.angle(local_frames[1, :]) - np.angle(local_frames[0, :]) - phase_adv
        local_dphi = local_dphi - 2 * np.pi * np.floor(local_dphi/(2 * np.pi))
        stft_new[i, :] =  local_mag * np.exp(phase*1j)
        phase += local_dphi + phase_adv

    if length:
        # honour the requested output length (previously accepted but ignored)
        out = librosa.core.istft(stft_new.transpose(), length=length)
    else:
        out = librosa.core.istft(stft_new.transpose())
    return out, x.shape[0]


In [None]:
def extraction(model,mysong,instrument):
  '''
  extract "instrument" from "mysong" using the "model" neural network

  The song is cut at every 4th detected beat. Each clip is resampled from
  44100 to 11025, time-stretched to the length of the training clips, passed
  through the model together with a target spectrogram drawn from the training
  data, and stretched back before being appended to the output.

  :param model: trained network taking a [1 x 2 x 256 x 256] input
  :param mysong: 1-D CPU tensor with the samples of the song (44100 assumed as its rate)
  :param instrument: instrument name understood by get_batch
  :return: np.ndarray with the concatenated predictions (empty if fewer than 5 beats were found)
  '''
  SONG_RATE = 44100
  TRAIN_RATE = 11025
  CLIP_SAMPLES = 18000    # sample length of the training clips

  # `y` is passed by keyword: recent librosa versions no longer accept it positionally
  tempo_ori,beats = librosa.beat.beat_track(y=mysong.numpy(),sr=SONG_RATE,start_bpm=100.0,units="samples",trim=True,tightness=100)  #identify beat locations
  resample = torchaudio.transforms.Resample(SONG_RATE, new_freq=TRAIN_RATE)   # built once and reused for every clip
  out = []
  bar = 4     # we are interested in 4-beat-long clips
  for i in range(0, len(beats) - bar, bar):
    _,_, _,_, _,_,test_x_target,_,test_instrument,_ = get_batch(2,instrument=instrument)    #fetch a target instrument clip from training data
    test_x_target = test_x_target[0].unsqueeze(0)

    song_sample = mysong[beats[i]:beats[i+bar]]   #generate clips from the song every 4 beats
    sample_resampled = resample(song_sample.view(1,-1))[0]   #resample clips to have the same sample rate as the training data

    sample_length = sample_resampled.shape[0]   #sample length of the resampled clip
    shrinkrate = sample_length/CLIP_SAMPLES      #stretching factor

    nfft = 1000 #window size of the fast Fourier transform used in stretching
    new_sample, inverse_shape = stretch(sample_resampled.numpy(),shrinkrate,nfft=nfft)  #stretch the clip so that it has the same tempo and hence sample length as the training clips. also save the length of the original clip
    new_sample_tensor = torch.from_numpy(new_sample).unsqueeze(0)
    sample_amp, sample_phase = _stft(new_sample_tensor)       #produce the magnitude and phase matrices of size 256x256 of the clip
    sample_heatmap = heatmap(sample_amp.unsqueeze(0).numpy())   #compute spectrogram
    sample_cnn_in = torch.cat([test_x_target,sample_heatmap], dim=1).to(device) #input of the model using the spectrogram of the clip from the song and the target one from the training dataset
    with torch.no_grad():     #inference only: no autograd graph is needed
      sample_p = torch.clamp(model(sample_cnn_in),0,1)     #predicted spectrogram, clamped

    amp_inverse = inverse_heatmap(sample_p.cpu())      #retrieve predicted magnitude
    wav = istft_reconstruction(amp_inverse,sample_phase)        #construct predicted audio
    wav = wav.cpu().numpy()[0]
    wav,_ = stretch(wav,1/shrinkrate,nfft = 512, length=inverse_shape)    #reverse the stretching of the audio / "unstretch" it using the previously saved original sample length of the clip, so that it has the same tempo as the original song
    out.append(wav)       #collect the predictions

  # concatenate once at the end instead of extending a Python list sample by sample
  return np.concatenate(out) if out else np.array([])

In [None]:
#pick a song to process
songs_dir = "drive/My Drive/Colab Notebooks/Bachata/songs/"
# every .mp3 in the folder, alphabetically
songdir = sorted(entry for entry in os.listdir(songs_dir) if entry.endswith('.mp3'))
songlist, songnames = [], []
for name in songdir:
    mysong_ori, sample_rate = torchaudio.load(songs_dir + name)
    mysong = mysong_ori[0]      # keep the first channel only
    songnames.append(name)
    songlist.append(mysong)

song_picker = widgets.Dropdown(options=songnames)
song_picker

In [None]:
def originalsong(name,songlist,songnames):
  '''
  Print the name of a song and display an audio player for it.

  `songnames` and `songlist` are parallel lists; the waveform at the position
  of `name` in `songnames` is played at a rate of 44100.
  '''
  print(f"original song {name}")
  position = songnames.index(name)
  display(Audio(songlist[position],rate=44100))

In [None]:
#play the song currently selected in the dropdown
originalsong(name=song_picker.value, songlist=songlist, songnames=songnames)

In [None]:
#extract all 4 instruments
# report which checkpoint and which song are being used
print(model_picker.value, song_picker.value, sep="\n")
instrument_list = ["bongo", "guira", "bass", "guitar"]
selected_position = songnames.index(song_picker.value)
song = songlist[selected_position]
for i in instrument_list:
  wav = extraction(N,song,i)
  print(f"\nextracted {i}")
  display(Audio(wav,rate=11025))



In [None]:
#extract a specific instrument
instrument = "guitar" #@param ["bongo", "guira", "bass", "guitar"]
selected_position = songnames.index(song_picker.value)
song = songlist[selected_position]
wav = extraction(N,song,instrument)
print(f"extracted {instrument}")
display(Audio(wav,rate=11025))