In [None]:
# Environment setup: print the interpreter version and install the packages imported below.
# NOTE(review): versions are unpinned, so re-runs may pick up different releases; consider pinning
# them and using `%pip` (which targets the running kernel) for reproducible runs.
!python --version
# PyTorch wheel from the CUDA 11.8 package index.
!pip install torch  --index-url https://download.pytorch.org/whl/cu118
!pip install matplotlib
!pip install soundfile
!pip install librosa audioread
!pip install mir_eval

In [None]:
import os
import re
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import librosa
import librosa.display
import torch.nn.functional as F
import datetime
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import mir_eval
from torch.optim.lr_scheduler import ReduceLROnPlateau

import gc

# Drop cached GPU memory and unreachable Python objects left over from earlier runs in this kernel.
torch.cuda.empty_cache()
gc.collect()

# Choose the compute device: CUDA if present, otherwise Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
test_path = "/kaggle/input/musdb18-hq/test"
train_path = "/kaggle/input/musdb18-hq/train"


def collect_tracks(split_path):
    """Collect the mixture and stem file paths of every track under ``split_path``.

    Each immediate sub-directory of ``split_path`` is treated as one track
    (processed in sorted order). Its ``mixture.wav`` is the model input; every
    other file in the directory whose name does not contain "mixture" (sorted
    by name) is a target stem.

    Args:
        split_path: directory holding one sub-directory per track.

    Returns:
        ``(mixtures, stems)``: parallel lists. ``mixtures[i]`` is a one-element
        list with the mixture path and ``stems[i]`` is the list of stem paths.
        Both are empty if ``split_path`` is not a directory.
    """
    mixtures, stems = [], []
    if not os.path.isdir(split_path):
        return mixtures, stems
    for track in sorted(os.listdir(split_path)):
        track_dir = os.path.join(split_path, track)
        if not os.path.isdir(track_dir):
            continue
        names = sorted(
            name for name in os.listdir(track_dir)
            if os.path.isfile(os.path.join(track_dir, name))
        )
        mixtures.append([os.path.join(track_dir, "mixture.wav")])
        stems.append([os.path.join(track_dir, name) for name in names if "mixture" not in name])
    return mixtures, stems


files = {"train_x": [], "train_y": [], "test_x": [], "test_y": []}

print("Processing specto for "+train_path)
files["train_x"], files["train_y"] = collect_tracks(train_path)

print("Processing specto for "+test_path)
files["test_x"], files["test_y"] = collect_tracks(test_path)

# Move the first third of the test tracks into the training set. The split index is computed
# once so inputs and targets are sliced identically (previously test_x was rebuilt from
# test_y, which turned every test "mixture" into a list of stem paths).
n_moved = len(files["test_x"]) // 3
files["train_x"] += files["test_x"][:n_moved]
files["train_y"] += files["test_y"][:n_moved]
files["test_x"] = files["test_x"][n_moved:]
files["test_y"] = files["test_y"][n_moved:]

assert len(files["train_x"]) == len(files["train_y"]), "train inputs/targets out of sync"
assert len(files["test_x"]) == len(files["test_y"]), "test inputs/targets out of sync"

for f in files.keys():
    print(f)
    print(len(files[f]))

In [None]:
freq_bins = 1024  # STFT bins kept per frame (the last of librosa's default 1025 bins is dropped)
tframe = 2048     # STFT frames per example; shorter clips are zero-padded, longer ones cropped


class SpectroDataset(Dataset):
    """Map-style dataset yielding ``(mixture, stems)`` STFT magnitudes for one track.

    ``train_x[idx]`` is a one-element list holding the mixture path and
    ``train_y[idx]`` is the list of stem paths for the same track.

    Every file is loaded with librosa's default settings. Its STFT magnitude is
    cropped to the first ``freq_bins`` bins and first ``tframe`` frames, then
    zero-padded symmetrically along time up to ``tframe`` frames.

    ``__getitem__`` returns the mixture tensor and a dict with any of the keys
    "bass", "drums", "other", "vocal" whose stem file was found.
    """

    # Substrings searched for in a stem's *file name*, in priority order, and used as dict keys.
    STEM_KEYS = ("bass", "drums", "other", "vocal")

    def __init__(self, train_x, train_y):
        self.train_x = train_x
        self.train_y = train_y

    def __len__(self):
        return len(self.train_x)

    @classmethod
    def _pad_frames(cls, spec):
        """Zero-pad a (freq, frames) tensor symmetrically in time up to ``tframe`` frames."""
        to_pad = tframe - spec.shape[1]
        if to_pad > 0:
            left = to_pad // 2
            # For odd padding the extra frame goes on the right.
            spec = F.pad(spec, (left, to_pad - left), mode='constant', value=0)
        return spec

    @classmethod
    def _load_magnitude(cls, path):
        """Load ``path`` and return its cropped, padded STFT magnitude on ``device``."""
        audio, _ = librosa.load(path)
        magnitude = np.abs(librosa.stft(audio))
        # NOTE(review): moving to the device inside the dataset assumes the DataLoader runs
        # in the main process (num_workers=0); CUDA tensors created in worker processes may fail.
        spec = torch.from_numpy(magnitude[:freq_bins, :tframe]).to(device)
        return cls._pad_frames(spec)

    @classmethod
    def _stem_key(cls, path):
        """Return the stem key for ``path`` based on its file name, or None if none matches.

        Only the base name is inspected: matching the whole path would let a track folder
        such as "Another Dreamer ..." label vocals.wav as "other".
        """
        name = os.path.basename(path)
        for key in cls.STEM_KEYS:
            if key in name:
                return key
        return None

    def __getitem__(self, idx):
        mixture = self._load_magnitude(self.train_x[idx][0])
        stems = dict()
        for path in self.train_y[idx]:
            key = self._stem_key(path)
            if key is not None:
                stems[key] = self._load_magnitude(path)
        return mixture, stems

In [None]:
class UNet(nn.Module):
    """Small 3-level U-Net that predicts four soft masks over a magnitude spectrogram.

    Encoder: three conv-conv-maxpool stages (16, 32, 64 channels) followed by a
    128-channel bottleneck. Decoder: stride-2 transposed convolutions that
    upsample, concatenate the matching encoder features, and refine. A final
    1x1 layer produces 4 channels that are softmax-normalised across the
    channel axis, so at every time-frequency point the four masks sum to 1.

    Layer attribute names form the checkpoint's state_dict keys and must not be renamed.
    """

    def __init__(self):
        super().__init__()

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)  # normalise across the 4 stem channels

        # Encoder
        self.e11 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.e12 = nn.Conv2d(16, 16, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e21 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.e22 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.e31 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.e32 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Bottleneck
        self.e51 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.e52 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Decoder: upconvN doubles the spatial size; dN mixes the concatenated skip features.
        self.upconv2 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.d2 = nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.upconv3 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.d3 = nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(32)

        self.upconv4 = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)
        self.d4 = nn.ConvTranspose2d(32, 16, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(16)

        # Output layer: one logit channel per stem.
        self.outconv = nn.ConvTranspose2d(16, 4, kernel_size=1)
        self.dropout = nn.Dropout(p=0.4)

    def forward_down(self, x):
        """Encoder pass on a (B, 1, F, T) tensor.

        Returns the bottleneck features and the skip features
        ``[xe32, xe22, xe12]`` ordered deepest first.
        """
        xe11 = self.relu(self.e11(x))
        xe12 = self.relu(self.e12(xe11))
        xp1 = self.pool1(xe12)

        xe21 = self.relu(self.e21(xp1))
        xe22 = self.relu(self.e22(xe21))
        xp2 = self.pool2(xe22)

        xe31 = self.relu(self.e31(xp2))
        xe32 = self.relu(self.e32(xe31))
        xp3 = self.pool3(xe32)

        # Bottleneck
        xe51 = self.relu(self.e51(xp3))
        xe52 = self.relu(self.e52(xe51))
        return xe52, [xe32, xe22, xe12]

    def forward_up(self, x, to_concats):
        """Decoder pass; returns softmax masks of shape (B, 4, F, T)."""
        xu2 = self.upconv2(x)
        xu21 = torch.cat((xu2, to_concats[0]), dim=1)
        xd2 = self.dropout(self.bn2(self.relu(self.d2(xu21))))

        xu3 = self.upconv3(xd2)
        xu31 = torch.cat((xu3, to_concats[1]), dim=1)
        xd3 = self.dropout(self.bn3(self.relu(self.d3(xu31))))

        xu4 = self.upconv4(xd3)
        xu41 = torch.cat((xu4, to_concats[2]), dim=1)
        xd4 = self.dropout(self.bn4(self.relu(self.d4(xu41))))

        # Output: dropout acts on the logits (train mode only); the softmax still sums to 1.
        out = self.outconv(xd4)
        out = self.dropout(out)
        out = self.softmax(out)
        return out

    def forward(self, input):
        """Split a batch of magnitude spectrograms into four stem estimates.

        Args:
            input: tensor of shape (B, F, T), or (F, T) treated as a batch of one.
                F and T must be divisible by 8 (three 2x poolings) so the skip
                connections line up with the upsampled features.

        Returns:
            dict with keys "vocal_stem", "bass_stem", "drums_stem", "other_stem",
            each of shape (B, F, T): the input multiplied by that stem's mask.
            Because the masks sum to 1, the four estimates sum to the input.
        """
        if input.dim() == 2:
            input = input.unsqueeze(0)
        # Add the channel axis: (B, F, T) -> (B, 1, F, T). unsqueeze(0) would have turned the
        # batch axis into channels, which only worked for batch size 1.
        x = input.unsqueeze(1)
        bottleneck, skips = self.forward_down(x)
        masks = self.forward_up(bottleneck, skips)

        return {
            "vocal_stem": masks[:, 0, :, :] * input,
            "bass_stem": masks[:, 1, :, :] * input,
            "drums_stem": masks[:, 2, :, :] * input,
            "other_stem": masks[:, 3, :, :] * input,
        }

In [None]:
class MultiTaskLoss(nn.Module):
    """Weighted sum of per-task L1 losses with one learnable weight per task.

    NOTE(review): the weights are raw, unconstrained parameters initialised to 1.
    If they are handed to an optimizer together with the model, gradient descent
    can reduce this loss simply by shrinking or negating them. A log-variance
    parameterisation (uncertainty weighting) avoids that; the behaviour is kept
    unchanged here.
    """

    def __init__(self, num_tasks):
        super().__init__()
        # Created on the CPU; the parameter follows the module's `.to(device)` like any other,
        # so construction no longer depends on a global `device` variable.
        self.weights = nn.Parameter(torch.ones(num_tasks))
        self.loss_fn = nn.L1Loss()

    def forward(self, predictions, targets):
        """Return sum_i weights[i] * L1(predictions[i], targets[i]).

        Args:
            predictions: sequence of prediction tensors, one per task.
            targets: sequence of target tensors aligned with ``predictions``.
        """
        total_loss = 0.0
        for task_idx in range(len(predictions)):
            task_loss = self.loss_fn(predictions[task_idx], targets[task_idx])
            total_loss = total_loss + self.weights[task_idx] * task_loss
        return total_loss

In [None]:
model = UNet().to(device)
multi_task_loss = MultiTaskLoss(4).to(device)
loss_fn = nn.L1Loss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
# Alternative optimiser: Adam over the model AND the learnable task weights, e.g.
#   optim.Adam([{'params': model.parameters()}, {'params': multi_task_loss.parameters()}],
#              lr=0.001, betas=(0.9, 0.999))
scheduler = ReduceLROnPlateau(optimizer, patience=5)

# Optional resume: set to a checkpoint file written by the training loop to continue from it.
# The optimizer above must be of the same type as the one the checkpoint was saved with.
resume_checkpoint_path = None
if resume_checkpoint_path is not None:
    checkpoint = torch.load(resume_checkpoint_path, map_location=device)
    print("loaded for resume training: " + resume_checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    multi_task_loss.load_state_dict(checkpoint['loss_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
EPOCHS = 50

best_vloss = 1_000_000.
best_sdr = -1.0
best_checkpoint_path = ""

train_dataloader = DataLoader(SpectroDataset(files["train_x"], files["train_y"]), batch_size=1, shuffle=True)
# Evaluation order does not need shuffling; a fixed order keeps validation runs reproducible.
test_dataloader = DataLoader(SpectroDataset(files["test_x"], files["test_y"]), batch_size=1, shuffle=False)

In [None]:
def train_one_epoch(epoch_index):
    """Run one optimisation pass over ``train_dataloader`` and return the mean batch loss.

    For each batch the four per-stem losses (``loss_fn``) are combined with
    Dynamic-Weight-Average-style weights. For batch ``i > 1``, each stem's ratio
    ``avg[i-1] / avg[i-2]`` of its running mean loss is divided by the
    temperature ``T``, softmax-normalised over the four stems, and scaled by
    ``K``. The first two batches use equal weights of 1.

    Args:
        epoch_index: 1-based epoch number (not used inside; kept for the caller).

    Returns:
        Mean of the weighted batch losses over the epoch.

    Raises:
        RuntimeError: if ``train_dataloader`` yields no batches.
    """
    running_loss = 0.
    n_batches = 0
    T, K = 2., 4.
    loss_sums = [0.0, 0.0, 0.0, 0.0]  # running per-stem sums of the unweighted losses
    avg_losses = []  # avg_losses[i]: per-stem running means after batch i (zeros replaced by 1e-9)
    for i, (mixtures, stems) in enumerate(train_dataloader):
        # Stand in silence for any stem the dataset did not find, shaped like the batch
        # (same shape, dtype and device as the mixture) so every loss term is well defined.
        for key in ("bass", "vocal", "drums", "other"):
            if key not in stems:
                stems[key] = torch.zeros_like(mixtures)

        optimizer.zero_grad()
        outputs = model(mixtures)

        # Dynamic Weight Average: stems whose running loss is falling more slowly get more weight.
        if i > 1:
            ratios = [avg_losses[i - 1][k] / avg_losses[i - 2][k] for k in range(4)]
        else:
            ratios = [1.] * 4
        softmx = F.softmax(torch.tensor(ratios) / T, dim=0)
        softmx[softmx == 0.0] = 1e-9  # never let a stem's weight vanish entirely
        dwa = K * softmx

        loss1 = loss_fn(outputs["vocal_stem"], stems["vocal"])
        loss2 = loss_fn(outputs["bass_stem"], stems["bass"])
        loss3 = loss_fn(outputs["drums_stem"], stems["drums"])
        loss4 = loss_fn(outputs["other_stem"], stems["other"])
        loss = dwa[0] * loss1 + dwa[1] * loss2 + dwa[2] * loss3 + dwa[3] * loss4

        weightless_loss = [loss1.item(), loss2.item(), loss3.item(), loss4.item()]
        n_batches += 1
        # Running sums replace re-summing the whole history every batch (was O(n^2) per epoch).
        loss_sums = [s + l for s, l in zip(loss_sums, weightless_loss)]
        current_avg = [s / n_batches for s in loss_sums]
        print("dwa: "+str(dwa))
        print('weightless_loss: '+str(weightless_loss))
        print("current_avg: "+ str(current_avg))
        # Zero averages would make the next ratio divide by zero.
        avg_losses.append([x if x != 0.0 else 1e-9 for x in current_avg])

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        print('training batch {}, loss: {}'.format(i + 1, loss.item()))

    if n_batches == 0:
        raise RuntimeError("train_dataloader yielded no batches; check files['train_x']")
    return running_loss / n_batches

In [None]:
avg_loss_history = []
for epoch in range(1, EPOCHS + 1):
    print('EPOCH {}:'.format(epoch))

    # Training pass: dropout active and batch-norm statistics updated.
    model.train(True)
    avg_loss = train_one_epoch(epoch)

    # Validation pass: eval mode disables dropout and makes batch norm use its running
    # statistics; no_grad avoids building the autograd graph.
    model.eval()
    running_vloss = 0.0
    n_val_batches = 0
    val_loss_sums = [0.0, 0.0, 0.0, 0.0]
    with torch.no_grad():
        for i, (mixtures, stems) in enumerate(test_dataloader):
            # Stand in silence for any missing stem, shaped like the mixture batch
            # (same shape, dtype and device) so every loss term is well defined.
            for key in ("bass", "vocal", "drums", "other"):
                if key not in stems:
                    stems[key] = torch.zeros_like(mixtures)
            outputs = model(mixtures)

            stem_losses = [
                loss_fn(outputs["vocal_stem"], stems["vocal"]),
                loss_fn(outputs["bass_stem"], stems["bass"]),
                loss_fn(outputs["drums_stem"], stems["drums"]),
                loss_fn(outputs["other_stem"], stems["other"]),
            ]
            # Fixed unit weights (what the DWA weights equal when all ratios match) so the
            # validation loss that drives checkpoint selection and the LR scheduler does not
            # depend on batch order or on running statistics within the epoch.
            loss = sum(stem_losses)
            weightless_loss = [l.item() for l in stem_losses]
            n_val_batches += 1
            val_loss_sums = [s + l for s, l in zip(val_loss_sums, weightless_loss)]
            current_avg = [s / n_val_batches for s in val_loss_sums]
            print('weightless_loss: '+str(weightless_loss))
            print("current_avg: "+ str(current_avg))
            running_vloss += loss.item()
            # TODO: SDR evaluation (mir_eval.separation.bss_eval_sources on reconstructed
            # waveforms) is not implemented here.
            print(f"testing batch {i+1}, loss: {loss.item()}")

    if n_val_batches == 0:
        raise RuntimeError("test_dataloader yielded no batches; check files['test_x']")
    avg_vloss = running_vloss / n_val_batches
    losses = 'LOSS epoch {}: train {} | valid {}'.format(epoch, avg_loss, avg_vloss)
    print(losses)
    avg_loss_history.append(losses)
    scheduler.step(avg_vloss)

    # Track best performance, and save the model's state
    if avg_vloss < best_vloss:
        print("new best found, saving checkpoint")
        best_checkpoint_path = '/kaggle/working/model_kaggle_{}_{}'.format(timestamp, epoch)
        best_vloss = avg_vloss
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_state_dict': multi_task_loss.state_dict(),
                'scheduler_state_dict': scheduler.state_dict()
                }, best_checkpoint_path)
    else:
        print("not the best compared to last")
print("\n".join(avg_loss_history))

In [None]:
torch.cuda.empty_cache()
gc.collect()

if not best_checkpoint_path:
    raise RuntimeError("no checkpoint was saved; run the training loop first")
model2 = UNet().to(device)
state_dict = torch.load(best_checkpoint_path, map_location=device)
print("loaded for inference: "+best_checkpoint_path)
model2.load_state_dict(state_dict['model_state_dict'])
model2.eval()

client_wav = "/kaggle/input/musics/shape_of_you.mp3"
wav_name = os.path.splitext(os.path.basename(client_wav))[0]
audio, sr = librosa.load(client_wav)
# Keep the complex STFT: magphase of the complex spectrum gives the real mixture phase.
# (magphase of np.abs(stft) would return an all-ones phase.)
magnitude, phase = librosa.magphase(librosa.stft(audio))

# The model sees the first `freq_bins` bins and first `tframe` frames, as in training.
# TODO: only the first `tframe` frames of the song are separated; process longer songs in chunks.
n_frames = min(magnitude.shape[1], tframe)
mixture = torch.from_numpy(magnitude[:freq_bins, :n_frames]).to(device)
to_pad = tframe - n_frames
pad_left = to_pad // 2
if to_pad > 0:
    mixture = F.pad(mixture, (pad_left, to_pad - pad_left), mode='constant', value=0)
mixture = mixture.unsqueeze(0)  # add batch dimension -> (1, freq_bins, tframe)

with torch.no_grad():
    outputs = model2(mixture)


def _estimate_to_wave(stem_magnitude):
    """Invert one estimated magnitude of shape (1, freq_bins, tframe) to audio.

    The time padding is cropped off. The estimate is combined with the mixture
    phase, and bins beyond `freq_bins` (dropped before the model) are restored
    as zeros. This keeps the spectrum at librosa's default STFT size, so
    `istft` uses the same n_fft/hop as the forward transform.
    """
    est = stem_magnitude.detach().cpu().numpy()[0, :, pad_left:pad_left + n_frames]
    spec = np.zeros((magnitude.shape[0], n_frames), dtype=phase.dtype)
    spec[:freq_bins] = est * phase[:freq_bins, :n_frames]
    return librosa.istft(spec)


# Output file suffix -> model output key.
stem_outputs = {
    "bass": "bass_stem",
    "drums": "drums_stem",
    "other": "other_stem",
    "vocal": "vocal_stem",
}
waves = {name: _estimate_to_wave(outputs[key]) for name, key in stem_outputs.items()}

torch.cuda.empty_cache()
gc.collect()

for name, wave in waves.items():
    print(name, wave.shape)

print("converted to wav")

for name, wave in waves.items():
    sf.write("/kaggle/working/"+wav_name+"_"+name+".wav", wave, sr)

print("end")