# **Installing Required Libraries**

This GAN architecture is taken from the following GitHub repository: https://github.com/mostafaelaraby/cyclic-gan-music-source-separation

In [None]:
!pip install pescador==0.1.3
!pip install museval

In [None]:
!pip install lmdb
!pip install joblib==1.0.0

In [None]:
!pip install torchsummary

# **Importing all the Necessary Libraries**

In [None]:
import os
import time
import math 
import random
import librosa
import librosa.display
import soundfile as sf
import lmdb

import numpy as np
from torch.utils import data

import glob
from scipy.io.wavfile import read as wavread #TODO needs to check how to make the reader faster
import pescador
import torch
import torch.nn as nn
import torch.nn.functional as F  
import torch.optim as optim
from torch.autograd import grad, Variable
from torchsummary import summary  
import logging 

import matplotlib.pyplot as plt
import matplotlib.animation as animation
import matplotlib
import copy
from IPython.display import HTML

from tqdm import tqdm_notebook as tqdm

# **Config**


In [None]:
#############################
# Model Params
#############################
# Name of the model to be saved (an earlier experiment used 'exp_large_1').
model_prefix = 'exp_musdb_1_wide_unpaired_ralsgan_4'
n_iterations = 100000   # an earlier setting was 1000000
# Author's guidance: for the cyclic GAN use the learning rates below with
# n_critic = 1; for training a plain WaveGAN model TTUR is preferable
# (lr_g = 1e-4, lr_d = 3e-4, n_critic = 1).
# NOTE(review): n_critic is set to 5 further down, which does not match the
# guidance above - confirm which value is intended.
lr_g = 2e-4
lr_d = 2e-4
# presumably the optimizer betas (not used in this part of the file) - TODO confirm
beta1 = 0.5
beta2 = 0.9
decay_lr = False
generator_batch_size_factor = 1 # in some cases we might try to update the generator with double batch size used in the discriminator
n_critic = 5 # update generator every n_critic steps
# gradient penalty regularization factor.
p_coeff = 10
batch_size = 10
noise_latent_dim = 100
model_capacity_size = 64 # reduce the capacity to 32 for faster training also in case of generating larger window size

In [None]:
# How often (in iterations, presumably - TODO confirm against the training loop)
# validation/cost values are stored and the progress bar is advanced.
store_cost_every = 300
progress_bar_step_iter_size = 400

In [None]:
# Backup params: whether to take backups and how often (both counters were
# previously 1000).
take_backup = True
backup_every_n_iters = 10000     #1000
save_samples_every = 10000    #1000

# Folders holding the audio data (train, valid and test subfolders live under
# the dataset root). The commented-out alternatives point at the Newari song
# dataset ('.../Newari_song/Bansuri' and '.../Newari_song/mixture').
target_signals_dir = '/kaggle/input/nepali-music-source-seperation/nepali_music_source_seperation/train/vocals' #'/kaggle/input/to-delete-renamed/Newari_song/Newari_song/Bansuri'
other_signals_dir  = '/kaggle/input/nepali-music-source-seperation/nepali_music_source_seperation/train/mixture'  #'/kaggle/input/to-delete-renamed/Newari_song/Newari_song/mixture'

# Generated samples are written below this directory (see save_samples).
output_dir = '/kaggle/working/output'

In [None]:
#############################
# Audio Reading Params
#############################
# Number of samples per training window. Supported values are
# [16384, 32768, 65536]; change model_capacity_size to 32 when using a slice
# longer than 16384.
window_length = 16384
# Sampling rate passed to librosa when loading and to soundfile when saving.
sampling_rate = 16000
# When True, load_wav rescales audio whose peak magnitude exceeds 1.
normalize_audio = True

In [None]:
#############################
# Torch Init and seed setting
#############################
# Fixed seed so repeated runs draw the same random numbers; replace it with
# random.randint(1, 10000) to get fresh results on every run.
manual_seed = 2014 #@param {type: "integer"}

# Run on the first GPU when one is available, otherwise on the CPU.
cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda else "cpu")

print("Random Seed: ", manual_seed)
# Seed every random number generator used in this file.
random.seed(manual_seed)
torch.manual_seed(manual_seed)
np.random.seed(manual_seed)
if cuda:
    torch.cuda.manual_seed(manual_seed)
    # Release cached GPU memory left over from earlier cells.
    torch.cuda.empty_cache()

In [None]:
#############################
# Logger init
#############################
# NOTE(review): the string passed to getLogger is only the logger's name, even
# though it looks like a file path; no handler is attached in this cell.
LOGGER = logging.getLogger('/kaggle/working/wavegan')
LOGGER.setLevel(logging.DEBUG)

## Utils

### File Utils

In [None]:
def get_recursive_files(folderPath,ext):
    """Collect every file below ``folderPath`` whose name ends with ``ext``.

    Sub-directories are descended into recursively. Entries are visited in
    the order ``os.listdir`` reports them.

    Args:
        folderPath: directory to scan.
        ext: filename suffix to match (for example ``'.wav'``).

    Returns:
        List of matching paths, each prefixed with ``folderPath``.
    """
    collected = []
    for entry in os.listdir(folderPath):
        entry_path = os.path.join(folderPath, entry)
        if os.path.isdir(entry_path):
            collected += get_recursive_files(entry_path, ext)
        elif entry.endswith(ext):
            collected.append(entry_path)
    return collected

def make_path(output_path):
    """Create ``output_path`` (and any missing parents) and return it.

    Args:
        output_path: directory to create.

    Returns:
        ``output_path`` unchanged, so the call can be used inline.

    Raises:
        FileExistsError: if ``output_path`` exists but is not a directory.
    """
    # exist_ok=True removes the race between checking for the directory and
    # creating it (for example when two workers create the same folder).
    os.makedirs(output_path, exist_ok=True)
    return output_path

###  Wav files utils

In [None]:
# Fast loading goes through scipy's WAV reader and only accepts int16 or
# float32 samples that are already at `sampling_rate`.
def load_wav(wav_file_path, fast_loading=False ):
    """Load an audio file as a float32 array of at least ``window_length`` samples.

    Args:
        wav_file_path: path of the audio file to read.
        fast_loading: when True read with ``scipy.io.wavfile.read`` (no
            resampling possible); otherwise read with ``librosa.load``, which
            is asked for ``sampling_rate``.

    Returns:
        The audio as ``float32``. When ``normalize_audio`` is set and the peak
        magnitude exceeds 1, the signal is divided by that peak. Audio shorter
        than ``window_length`` is zero-padded on both sides.

    Raises:
        NotImplementedError: in fast mode, when the file's sampling rate
            differs from ``sampling_rate`` or its dtype is neither int16 nor
            float32.
        Exception: any loading error is printed, logged and re-raised.
    """
    try:
        if not fast_loading:
            audio_data, _ = librosa.load(wav_file_path, sr=sampling_rate)
        else:
            file_sampling_rate, audio_data = wavread(wav_file_path)
            if file_sampling_rate is not None and sampling_rate != file_sampling_rate:
                raise NotImplementedError('Scipy cannot resample audio.')
            if audio_data.dtype == np.int16:
                # Map the int16 range onto [-1, 1).
                audio_data = audio_data.astype(np.float32)
                audio_data /= 32768.
            elif audio_data.dtype == np.float32:
                audio_data = np.copy(audio_data)
            else:
                raise NotImplementedError('Scipy cannot process atypical WAV files.')

        if normalize_audio:
            # Only rescale signals whose magnitude goes beyond 1.
            peak = np.max(np.abs(audio_data))
            if peak > 1:
                audio_data /= peak
    except Exception as e:
        print(wav_file_path)
        print(str(e))
        LOGGER.error("Could not load {}: {}".format(wav_file_path, str(e)))
        raise

    missing = window_length - len(audio_data)
    if missing > 0:
        # Split the padding between both ends; the extra sample goes right.
        left_pad = missing // 2
        audio_data = np.pad(audio_data, (left_pad, missing - left_pad), mode='constant')

    return audio_data.astype('float32')


def save_samples(epoch_samples, epoch ,prefix=''):
    """Write each sample of ``epoch_samples`` to ``<output_dir>/<epoch>/``.

    Files are named ``"<prefix>_<n>.wav"`` with ``n`` starting at 1. Only the
    first element (``sample[0]``) of every sample is written, at
    ``sampling_rate``.

    Args:
        epoch_samples: iterable of indexable samples.
        epoch: value used (via ``str``) as the sub-folder name.
        prefix: filename prefix.
    """
    target_dir = make_path(os.path.join(output_dir, str(epoch)))

    for number, sample in enumerate(epoch_samples, start=1):
        file_name = "{}_{}.wav".format(prefix, number)
        sf.write(os.path.join(target_dir, file_name), sample[0], sampling_rate)

In [None]:
def sample_audio(audio_data, start_idx=None, end_idx=None, window_size=None):
    """Cut one window out of ``audio_data``.

    Args:
        audio_data: 1-D NumPy array of samples.
        start_idx: first sample of the window; drawn at random when either
            index is ``None``.
        end_idx: one past the last sample of the window.
        window_size: window length in samples; defaults to the module-level
            ``window_length``.

    Returns:
        ``(sample, start_idx, end_idx)`` where ``sample`` is ``float32``.
        When the audio is exactly one window long it is returned whole and
        the indices are passed through unchanged (possibly ``None``).
    """
    if window_size is None:
        window_size = window_length
    audio_len = len(audio_data)
    if audio_len == window_size:
        # If we only have a single 1*window_size audio, just yield.
        sample = audio_data
    else:
        if start_idx is None or end_idx is None:
            # Random starts are drawn from [0, (audio_len - window_size) // 2).
            max_start = (audio_len - window_size) // 2
            # randint needs high > low: for audio at most one sample longer
            # than the window (or shorter than it) fall back to start 0
            # instead of raising ValueError.
            start_idx = np.random.randint(0, max(1, max_start))
            end_idx = start_idx + window_size
        sample = audio_data[start_idx:end_idx]
    sample = sample.astype('float32')
    assert not np.any(np.isnan(sample))
    return sample, start_idx, end_idx

def audio_generator(audio_data):
    """Yield ``audio_data`` as consecutive windows of ``window_length`` samples.

    Exactly ``len(audio_data) // window_length + 1`` windows are produced;
    the last one is zero-padded on the right. When the length is an exact
    multiple of ``window_length`` the final window is therefore all zeros.
    Each window is a fresh array created with ``np.zeros`` (float64).
    """
    n_windows = len(audio_data) // window_length + 1
    for offset in range(0, n_windows * window_length, window_length):
        chunk = audio_data[offset:offset + window_length]
        padded = np.zeros(window_length)
        padded[:chunk.shape[0]] = chunk
        yield padded
    
    
def sample_buffer(buffer_data, start_idx=None, end_idx=None, window_size=None):
    """Cut one window out of a raw byte buffer of 4-byte samples.

    Args:
        buffer_data: bytes-like object holding 4 bytes per sample (callers in
            this file decode it as float32).
        start_idx: first sample of the window, in samples (not bytes); drawn
            at random when either index is ``None``.
        end_idx: one past the last sample of the window, in samples.
        window_size: window length in samples; defaults to the module-level
            ``window_length``.

    Returns:
        ``(sample_bytes, start_idx, end_idx)``. When the buffer holds exactly
        one window it is returned whole and the indices are passed through
        unchanged (possibly ``None``).
    """
    if window_size is None:
        window_size = window_length
    bytes_per_sample = 4
    audio_len = len(buffer_data) // bytes_per_sample
    if audio_len == window_size:
        # If we only have a single 1*window_size audio, just yield.
        sample = buffer_data
    else:
        if start_idx is None or end_idx is None:
            # Random starts are drawn from [0, (audio_len - window_size) // 2).
            max_start = (audio_len - window_size) // 2
            # randint needs high > low: fall back to start 0 for buffers that
            # are at most one sample longer than the window instead of
            # raising ValueError.
            start_idx = np.random.randint(0, max(1, max_start))
            end_idx = start_idx + window_size
        sample = buffer_data[start_idx * bytes_per_sample:end_idx * bytes_per_sample]
    return sample, start_idx, end_idx


def wav_generator(file_path, mixing_signal_path):
    """Endlessly yield a clean window and a randomly weighted mixture of it.

    Both files are loaded once. On every iteration a window is sampled from
    the first file, the window at the same indices is taken from the second
    file, and the two are blended with a weight drawn uniformly from [0, 1).

    Yields:
        dict with ``'single'`` (the clean window) and ``'mixed'``
        (``w * mixing_window + (1 - w) * clean_window``).
    """
    source_audio = load_wav(file_path)
    interference_audio = load_wav(mixing_signal_path)
    while True:
        clean, first, last = sample_audio(source_audio)
        interference = sample_audio(interference_audio, first, last)[0]

        weight = np.random.uniform(0, 1)
        yield {
            'single': clean,
            'mixed': weight * interference + (1-weight) * clean,
        }

def create_stream_reader(single_signal_file_list, other_signal_file_list):
    """Build a batched pescador stream over all single-signal files.

    Every file in ``single_signal_file_list`` gets its own ``wav_generator``
    streamer, paired with one randomly chosen file from
    ``other_signal_file_list``. The streamers are combined with a
    ``ShuffledMux`` and buffered into batches of ``batch_size``.
    """
    n_other = len(other_signal_file_list)
    streamers = [
        pescador.Streamer(
            wav_generator,
            signal_path,
            other_signal_file_list[np.random.randint(0, n_other)],
        )
        for signal_path in single_signal_file_list
    ]
    return pescador.buffer_stream(pescador.ShuffledMux(streamers), batch_size)

### Model Utils

In [None]:

def sample_noise(size):
    """Return a ``(size, noise_latent_dim)`` float32 tensor of standard-normal noise.

    The tensor lives on ``device`` and is filled in place there.
    """
    latent = torch.empty(size, noise_latent_dim, dtype=torch.float32, device=device)
    return latent.normal_()  # latent vectors follow a normal distribution

def weights_init(m):
    """Initialise one module in place (intended for ``model.apply``).

    ``nn.Conv1d`` weights are drawn from a normal distribution with mean 0
    and standard deviation 0.02; biases of ``nn.Conv1d`` and ``nn.Linear``
    layers are zeroed. Other module types are left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, nn.Conv1d):
        m.weight.data.normal_(0.0, 0.02)
        # Layers created with bias=False have no bias tensor to reset.
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        if m.bias is not None:
            m.bias.data.fill_(0)

def update_optimizer_lr(optimizer,lr,decay):
    """Set the learning rate of every parameter group to ``lr * decay``."""
    for group in optimizer.param_groups:
        group.update(lr=lr * decay)


def gradients_status(model,flag):
    """Switch gradient tracking on or off for every parameter of ``model``.

    Args:
        model: module whose ``parameters()`` are updated.
        flag: value assigned to each parameter's ``requires_grad``.
    """
    for parameter in model.parameters():
        parameter.requires_grad = flag

In [None]:

# Replays previously generated samples to the discriminator so that it does
# not greedily adapt to the newest generator output only.
class ReplayBuffer():
    """Fixed-capacity pool of past samples.

    While the pool is not full, new samples are stored and returned as-is.
    Once full, each new sample either replaces a random stored sample (which
    is returned in its place) or is returned directly, with equal probability.
    """

    def __init__(self, max_size=50):
        assert (max_size > 0), 'Empty buffer or trying to create a black hole. Be careful.'
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, data):
        """Push a batch into the pool and return a batch of the same size.

        Args:
            data: batch tensor; it is iterated along its first dimension.

        Returns:
            The concatenation of one returned sample per input sample.
        """
        batch_out = []
        for item in data.data:
            item = torch.unsqueeze(item, 0)
            if len(self.data) < self.max_size:
                # Still filling up: remember the sample and pass it through.
                self.data.append(item)
                batch_out.append(item)
                continue
            if random.uniform(0,1) > 0.5:
                # Hand out an old sample and store the new one in its slot.
                slot = random.randint(0, self.max_size-1)
                batch_out.append(self.data[slot].clone())
                self.data[slot] = item
            else:
                batch_out.append(item)
        return Variable(torch.cat(batch_out))

### Data Loader and Sampler

In [None]:
def numpy_to_tensor(numpy_array):
    """Insert a singleton axis at position 1 and move the data to ``device``.

    The array is converted with ``torch.Tensor``, so the result uses torch's
    default floating-point tensor type.
    """
    with_channel_axis = numpy_array[:, None, :]
    return torch.Tensor(with_channel_axis).to(device)

#############################
# Creating Data Loader and Sampler
#############################
class WavDataLoader():
    """Endless iterator over batches streamed from folders of audio files.

    Args:
        folder_path: folder searched recursively for the single-source files.
        other_signals_folder: folder searched recursively for the signals that
            get mixed into the single-source files.
        audio_extension: filename suffix used to select files.
    """

    def __init__(self, folder_path, other_signals_folder, audio_extension='wav'):
        self.signal_paths = get_recursive_files(folder_path, audio_extension)
        self.mixed_wav_files = get_recursive_files(other_signals_folder, audio_extension)
        self.data_iter = None
        self.initialize_iterator()

    def initialize_iterator(self):
        """(Re)create the underlying batch stream."""
        data_iter = create_stream_reader(self.signal_paths, self.mixed_wav_files)
        self.data_iter = iter(data_iter)

    def __len__(self):
        """Number of single-source files found."""
        return len(self.signal_paths)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next batch as tensors: ``(single, mixed[, foreground])``.

        The stream built by ``create_stream_reader`` only carries the
        ``'single'`` and ``'mixed'`` entries, so unconditionally reading
        ``'foreground'`` raised ``KeyError`` on every batch. The third tensor
        is now only included when the stream actually provides it.
        """
        x = next(self.data_iter)
        keys = ['single', 'mixed']
        if 'foreground' in x:
            keys.append('foreground')
        return tuple(numpy_to_tensor(x[key]) for key in keys)



In [None]:
class LMDBWavLoader(data.Dataset):
    """Dataset reading (vocals, mixture) pairs from an LMDB file.

    Each record is a serialized ``DataNum`` message stored under a
    zero-padded 8-digit key; its ``vocals`` and ``mixture`` fields hold raw
    float32 samples.

    Args:
        lmdb_file_path: path of the LMDB environment to open read-only.
        is_test: when True ``__getitem__`` returns the full-length signals
            instead of one random window.
    """

    def __init__(self, lmdb_file_path, is_test=False):
        self.env = lmdb.open(lmdb_file_path, max_readers=1, readonly=True, lock=False,
                             readahead=False, meminit=False)
        self.datum = datanum_pb2.DataNum()
        self.is_test = is_test

    def __len__(self):
        """Number of records stored in the LMDB environment."""
        return int(self.env.stat()['entries'])

    def __getitem__(self, data_indx):
        """Return ``(vocals, mixture)`` for record ``data_indx``.

        In test mode both are flat float32 arrays covering the whole track.
        Otherwise both are ``(1, n_samples)`` float32 arrays cut from the same
        randomly chosen window.

        Raises:
            IndexError: if no record is stored under ``data_indx``.
        """
        key = '{:08}'.format(data_indx).encode('ascii')
        with self.env.begin(write=False) as txn:
            raw_datum = txn.get(key)
        if raw_datum is None:
            # Fail with a clear message instead of a TypeError from protobuf.
            raise IndexError('no LMDB record for index {}'.format(data_indx))
        self.datum.ParseFromString(raw_datum)

        if self.is_test:
            vocals = np.array(np.frombuffer(self.datum.vocals, dtype=np.float32)).reshape(-1)
            mixture = np.array(np.frombuffer(self.datum.mixture, dtype=np.float32)).reshape(-1)
            return vocals, mixture

        # Sample the window on the mixture, then reuse the same indices for
        # the vocals so both excerpts stay aligned.
        mixture_bytes, start_idx, end_idx = sample_buffer(self.datum.mixture, None, None)
        vocals_bytes, _, _ = sample_buffer(self.datum.vocals, start_idx, end_idx)
        # np.array(...) copies the read-only frombuffer view so that both
        # outputs are writable arrays.
        mixture = np.array(np.frombuffer(mixture_bytes, dtype=np.float32)).reshape(1,-1)
        vocals = np.array(np.frombuffer(vocals_bytes, dtype=np.float32)).reshape(1,-1)
        return vocals, mixture

### Visualize utils

In [None]:
def visualize_audio(audio_tensor, is_monphonic=False):
    """Plot every item of a batch as a waveform or as a spectrogram.

    Args:
        audio_tensor: tensor indexed as (batch, channel, samples); only the
            first channel of each item is plotted.
        is_monphonic: when True draw waveforms, otherwise draw
            linear-frequency spectrograms in dB.
    """
    input_audios = audio_tensor.detach().cpu().numpy()
    n_cols = 2
    # Keep the 10x2 grid for small batches, but grow it so that batches with
    # more than 20 items no longer overflow the subplot grid.
    n_rows = max(10, (len(input_audios) + n_cols - 1) // n_cols)
    plt.figure(figsize=(18, 5 * n_rows))
    if is_monphonic:
        # Newer librosa releases provide waveshow in place of waveplot; use
        # whichever the installed version offers.
        plot_waveform = getattr(librosa.display, 'waveshow', None) or librosa.display.waveplot
    for i , audio in enumerate(input_audios):
        plt.subplot(n_rows, n_cols, i+1)
        if is_monphonic:
            plt.title('Monophonic %i' % (i+1))
            plot_waveform(audio[0], sr=sampling_rate)
        else:
            D = librosa.amplitude_to_db(np.abs(librosa.stft(audio[0])), ref=np.max)
            librosa.display.specshow(D, y_axis='linear')
            plt.colorbar(format='%+2.0f dB')
            plt.title('Linear-frequency power spectrogram %i' % (i+1))
    plt.show()

def visualize_loss(loss_1, loss_2, first_legend, second_legend, y_label):
    """Plot two loss curves against the iteration index on one figure.

    Args:
        loss_1: values of the first curve.
        loss_2: values of the second curve.
        first_legend: legend label of the first curve.
        second_legend: legend label of the second curve.
        y_label: label of the vertical axis.
    """
    figure, axes = plt.subplots(figsize=(10,5))
    axes.set_title("{} and {} Loss During Training".format(first_legend, second_legend))
    for values, label in ((loss_1, first_legend), (loss_2, second_legend)):
        axes.plot(values, label=label)
    axes.set_xlabel("iterations")
    axes.set_ylabel(y_label)
    axes.grid(True)
    figure.tight_layout()
    axes.legend()
    plt.show()

In [None]:
def latent_space_interpolation(model, n_samples=10 ):
    """Generate audio along a straight line between two random latent vectors.

    Two noise vectors are sampled, ``n_samples`` evenly spaced blends of them
    are fed through ``model`` without gradient tracking, and the outputs are
    plotted as waveforms.
    """
    endpoints = sample_noise(2)
    with torch.no_grad():
        blends = torch.stack([
            weight * endpoints[0] + ((1 - weight) * endpoints[1])
            for weight in np.linspace(0, 1, n_samples)
        ])
        generated_audio = model(blends)
    visualize_audio(generated_audio, True)

In [None]:
# Start from an empty output directory. shutil.rmtree replaces the shell
# call `rm -r <path>`, which interpolated the path unquoted and so broke on
# paths containing spaces or shell metacharacters. ignore_errors=True keeps
# the previous best-effort behaviour (a failed removal did not stop the run).
import shutil

if os.path.isdir(output_dir):
    shutil.rmtree(output_dir, ignore_errors=True)
make_path(output_dir)


# **Data Loading**

In [None]:
!git clone https://github.com/mostafaelaraby/cyclic-gan-music-source-separation.git
    
### The code below relies on a dependency file (the protocol buffer definitions) from the GitHub repository cloned above

In [None]:
import sys
sys.path.append('/kaggle/working/cyclic-gan-music-source-separation')

In [None]:
from protocol_buffer import datanum_pb2

In [None]:
import numpy as np
import lmdb
## import caffe
## from   utils  import get_recursive_files, sample_audio, load_wav, manual_seed
manual_seed=2014 
sampling_rate=16000 
window_length=16384
from sklearn.model_selection import train_test_split
import wave
import librosa
from tqdm import tqdm
 
## protoc -I=$SRC_DIR --python_out=$DST_DIR/datanum_pb2.py $SRC_DIR/datanum.proto
## to create protoc


def get_map_size(files):
    """Estimate the LMDB map size needed to store ``files``.

    The byte size of the first file's audio is used as the per-file size,
    multiplied by 10 and by ``len(files) + 2``.
    """
    first_audio_bytes = load_wav(files[0]).nbytes
    return first_audio_bytes * 10 * (len(files) +2)

def get_silent_set(input_audio):
    """Find long runs of exactly-zero samples in ``input_audio``.

    Returns:
        List of ``(first_index, prev_index, counter)`` tuples, one per run
        whose ``counter`` exceeded the hard-coded window of 16384.

    NOTE(review): ``counter`` counts steps between adjacent zero indices, not
    samples. For a run that does not begin at index 0 it ends up one less
    than the run length, and ``first_index`` lands on the run's second
    sample. Confirm whether this offset is intended before relying on the
    returned bounds.
    """
    # Positions of all samples that are exactly zero.
    indices = np.where(input_audio==0)[0]
    index_sets = []
    # Minimum counter value (exclusive) for a run to be reported; this
    # duplicates the module-level window_length value.
    window = 16384
    counter = 0
    prev_index = -1
    first_index = -1
    for index in indices:
        # Whenever no run is being counted, remember the current position as
        # the candidate start.
        if counter==0:
            first_index = index
        if index - prev_index == 1:
            # Directly follows the previous zero: extend the current run.
            counter +=1
        else:
            # Gap found: report the run that just ended if it was long
            # enough, then start counting afresh.
            if counter>window:
                index_sets.append((first_index, prev_index, counter))
            counter = 0
        prev_index = index
    # Report a run that extends to the last zero sample.
    if counter>window:
        index_sets.append((first_index, prev_index, counter))
    return index_sets

def remove_silence(input_audio, index_sets):
    """Remove the sample ranges listed in ``index_sets`` from ``input_audio``.

    Args:
        input_audio: 1-D NumPy array of samples.
        index_sets: iterable of tuples whose first two items are the first
            index to remove and the (exclusive) end index, both expressed in
            coordinates of the original ``input_audio``.

    Returns:
        A new array without the listed ranges; ``input_audio`` itself when
        ``index_sets`` is empty.
    """
    if not index_sets:
        return input_audio
    # Delete every range in a single call. Deleting range by range shifted
    # the remaining samples, so all ranges after the first one removed the
    # wrong positions.
    to_drop = np.concatenate(
        [np.arange(silent_indices[0], silent_indices[1]) for silent_indices in index_sets]
    )
    return np.delete(input_audio, to_drop)
    
## Adapted from
## https://github.com/francesclluis/source-separation-wavenet/blob/6d89618c77d38960c3996219f329e5806573799b/util.py#L220
## Returns start and end indices, interleaved: the first of each pair is the
## start of an active region, the second is its end.

def get_sequence_with_singing_indices(full_sequence, chunk_length = 800):
    """Locate the regions of ``full_sequence`` that contain signal energy.

    The signal is split into chunks of ``chunk_length`` samples. A chunk is
    active when its mean absolute amplitude is non-zero and at least 10% of
    the loudest chunk's. Only regions that both start and end inside the
    signal are reported.

    Args:
        full_sequence: 1-D array of samples.
        chunk_length: number of samples per energy chunk.

    Returns:
        1-D integer array ``[start_0, end_0, start_1, end_1, ...]`` of sample
        positions; empty when there is no complete active region.
    """
    signal_magnitude = np.abs(full_sequence)

    chunks_energies = np.asarray([
        np.mean(signal_magnitude[i:i + chunk_length])
        for i in range(0, len(signal_magnitude), chunk_length)
    ])
    if chunks_energies.size == 0:
        # Empty input: nothing to report (np.max would raise on it).
        return np.zeros(0, dtype=np.int64)

    threshold = np.max(chunks_energies) * .1
    chunks_energies[chunks_energies < threshold] = 0
    active = np.zeros(len(chunks_energies))
    active[np.nonzero(chunks_energies)] = 1
    # +1 marks an inactive->active transition, -1 the opposite.
    onsets = np.diff(active)

    # flatnonzero always yields 1-D arrays. The earlier squeeze(where(...))
    # produced a 0-d array for a single transition, and indexing it (or an
    # empty array) raised IndexError.
    start_ind = np.flatnonzero(onsets == 1)
    finish_ind = np.flatnonzero(onsets == -1)

    # Drop an end that has no preceding start (signal begins active) ...
    if finish_ind.size and (start_ind.size == 0 or finish_ind[0] < start_ind[0]):
        finish_ind = finish_ind[1:]

    # ... and a start that has no following end (signal finishes active).
    if start_ind.size and (finish_ind.size == 0 or start_ind[-1] > finish_ind[-1]):
        start_ind = start_ind[:-1]

    # Interleave as start_0, end_0, start_1, end_1, ...
    indices_inici_final = np.column_stack((start_ind, finish_ind)).ravel()

    # The diff index i describes the boundary between chunk i and chunk i+1.
    return (indices_inici_final + 1) * chunk_length

def write_lmdb(out_file_name, data_list):
    """Serialize every mixture/vocals pair of ``data_list`` into an LMDB file.

    Args:
        out_file_name: path of the LMDB environment to create or open.
        data_list: paths of the mixture files. The matching vocals path is
            derived by replacing 'mixture' with 'vocals' in each path.

    Each record is stored under a zero-padded 8-digit key and contains the
    raw float32 bytes of both signals plus the int64 bytes of the vocal
    activity boundaries.
    """
    lmdb_output = lmdb.open(out_file_name, map_size=get_map_size(data_list))
    try:
        with lmdb_output.begin(write=True) as txn:
            # txn is a Transaction object
            for audio_indx, audio_path in enumerate(tqdm(data_list)):
                mixed_data = load_wav(audio_path).astype('float32')
                vocals_data = load_wav(audio_path.replace('mixture','vocals')).astype('float32')   #Bansuri---->vocals

                # Optional, currently disabled: strip long silent stretches
                # from both signals based on the vocals, using
                # get_silent_set(vocals_data) followed by remove_silence(...)
                # on mixed_data and vocals_data.

                vocals_indices = get_sequence_with_singing_indices(vocals_data, 800)

                datum = datanum_pb2.DataNum()
                datum.mixture = mixed_data.tobytes()
                datum.vocals = vocals_data.tobytes()
                # Store the indices having voice with an explicit dtype so the
                # byte layout does not depend on the platform's index type.
                datum.vocals_indices = np.asarray(vocals_indices).astype(np.int64).tobytes()
                str_id = '{:08}'.format(audio_indx)
                txn.put(str_id.encode('ascii'), datum.SerializeToString())
    finally:
        # Release the environment's file handles and memory map.
        lmdb_output.close()

def create_lmdb(folder_name, out_file_name, is_train=True):
    """Build the LMDB file(s) for all '.wav' files found below ``folder_name``.

    Args:
        folder_name: folder searched recursively for mixture '.wav' files.
        out_file_name: name of the LMDB file to write.
        is_train: when True, 15% of the files are held out (using
            ``manual_seed``) and written to a second LMDB file whose name is
            ``out_file_name`` without '_train' plus '_valid'.
    """
    mixture_audio_train = get_recursive_files(folder_name,'.wav')
    mixture_audio_valid = None
    if is_train:
        # create validation set
        mixture_audio_train, mixture_audio_valid = train_test_split(mixture_audio_train, test_size = 0.15, random_state = manual_seed)
    write_lmdb(out_file_name, mixture_audio_train)

    # Only write a validation file when a split was made; previously this ran
    # unconditionally and raised NameError for is_train=False.
    if mixture_audio_valid is not None:
        write_lmdb(out_file_name.replace('_train','')+'_valid', mixture_audio_valid)

# Build the train/validation LMDB files from the mixture folder
# (an alternative dataset lives at "/kaggle/input/newari-music/Newari_song/mixture").
parent_folder = "/kaggle/input/nepali-music-source-seperation/nepali_music_source_seperation/train/mixture"
create_lmdb(parent_folder, 'musdb_train')

### parent_folder = "/kaggle/input/musdb18ying/test"
### create_lmdb(parent_folder, 'musdb_test', False)


### Testing Proto buffer reading
for lmdb_name in ['musdb_train', 'musdb_valid']:
    env = lmdb.open(lmdb_name, readonly=True)
    try:
        with env.begin() as txn:
            raw_datum = txn.get(b'00000000')
    finally:
        # Close the environment once the first record has been fetched.
        env.close()

    datum = datanum_pb2.DataNum()
    datum.ParseFromString(raw_datum)

    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    mixture = np.frombuffer(datum.mixture, dtype=np.float32)
    vocals = np.frombuffer(datum.vocals, dtype=np.float32)
    # The indices are produced from NumPy index arrays, which are 64-bit on
    # 64-bit platforms; decoding them as int32 splits every value in two.
    vocals_indices = np.frombuffer(datum.vocals_indices, dtype=np.int64)
    print(mixture.shape)
    print(vocals.shape)
    print(vocals_indices)


In [None]:
# Wrap the LMDB files written above in datasets and batch them with torch
# DataLoaders. Both loaders shuffle, use one worker process, drop the last
# incomplete batch and pin memory.
# NOTE(review): the validation loader also shuffles and drops the last
# partial batch - confirm this is intended.
train_dataset = LMDBWavLoader('/kaggle/working/musdb_train')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                        shuffle=True, num_workers=1,drop_last=True,pin_memory=True)
val_dataset = LMDBWavLoader('/kaggle/working/musdb_valid')
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size,
                                        shuffle=True, num_workers=1,drop_last=True,pin_memory=True)
# Print the objects' default representations as a quick sanity check.
print(val_dataset,train_dataset,train_loader,val_loader)

## Model

In [None]:
class Transpose1dLayer(nn.Module):
    """Upsampling block used by the generators.

    With a truthy ``upsample`` the block is nearest-neighbour interpolation
    by that factor (in ``forward``) followed by zero padding of
    ``kernel_size // 2`` on both sides and a regular ``Conv1d``. Otherwise it
    is a single ``ConvTranspose1d``. A ``BatchNorm1d`` is appended when
    ``use_batch_norm`` is set.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: kernel size of the convolution.
        stride: stride of the convolution.
        padding: padding of the transposed convolution.
        upsample: interpolation scale factor, or a falsy value to use the
            transposed convolution.
        output_padding: output padding of the transposed convolution.
        use_batch_norm: whether to append batch normalisation.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding=11, upsample=None, output_padding=1, use_batch_norm = False):
        super().__init__()
        self.upsample = upsample

        # Both candidate convolutions are always instantiated, in this order,
        # exactly as before: layer construction draws random initial weights,
        # so skipping one would change the random numbers later layers get.
        resize_conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride)
        resize_conv.weight.data.normal_(0.0, 0.02)
        transposed_conv = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride, padding, output_padding)

        if self.upsample:
            # Constant (zero) padding keeps the length unchanged at stride 1.
            layers = [nn.ConstantPad1d(kernel_size // 2, value=0), resize_conv]
        else:
            layers = [transposed_conv]

        if use_batch_norm:
            layers.append(nn.BatchNorm1d(out_channels))
        self.transpose_ops = nn.Sequential(*layers)

    def forward(self, x):
        if not self.upsample:
            return self.transpose_ops(x)
        # recommended by wavgan paper to use nearest upsampling
        upsampled = nn.functional.interpolate(x, scale_factor=self.upsample, mode='nearest')
        return self.transpose_ops(upsampled)


In [None]:
class Conv1D(nn.Module):
    """Strided 1-D convolution followed by optional batch norm, leaky ReLU,
    optional phase shuffle and optional dropout.

    Args:
        input_channels: channels of the input.
        output_channels: channels of the output.
        kernel_size: kernel size of the convolution.
        alpha: negative slope of the leaky ReLU.
        shift_factor: shift factor handed to ``PhaseShuffle``.
        stride: stride of the convolution.
        padding: padding of the convolution.
        use_batch_norm: whether batch normalisation is applied.
        drop_prob: dropout probability; dropout is applied only when > 0.
    """

    def __init__(self,input_channels, output_channels, kernel_size,alpha = 0.2,shift_factor=2, stride=4, padding=11, use_batch_norm=False, drop_prob = 0):
        super().__init__()
        self.alpha = alpha
        self.use_batch_norm = use_batch_norm
        # NOTE(review): the flag is True only for shift_factor == 0, in which
        # case PhaseShuffle returns its input unchanged - so phase shuffling
        # is effectively never applied. This looks inverted, but the
        # generator blocks in this file also pass a non-zero shift_factor, so
        # flipping it would change them as well. Confirm before changing.
        self.use_phase_shuffle = shift_factor==0
        self.use_drop = drop_prob>0

        # Sub-modules, registered in their original order.
        self.conv1d = nn.Conv1d(input_channels, output_channels, kernel_size, stride=stride, padding=padding)
        self.batch_norm = nn.BatchNorm1d(output_channels)
        self.phase_shuffle =  PhaseShuffle(shift_factor)
        self.dropout = nn.Dropout2d(drop_prob)

    def forward(self, x):
        out = self.conv1d(x)
        if self.use_batch_norm:
            out = self.batch_norm(out)
        out = F.leaky_relu(out, negative_slope=self.alpha)
        if self.use_phase_shuffle:
            out = self.phase_shuffle(out)
        return self.dropout(out) if self.use_drop else out

In [None]:
class PhaseShuffle(nn.Module):
    """
    Phase shuffling: shifts the last axis of a 3D tensor by a random integer
    drawn per sample from [-n, n], filling the vacated positions with
    reflection padding.
    """
    # Copied from https://github.com/jtcramer/wavegan/blob/master/wavegan.py#L8
    def __init__(self, shift_factor):
        super().__init__()
        self.shift_factor = shift_factor

    def forward(self, x):
        if self.shift_factor == 0:
            return x
        # One shift per sample, uniform over the integers in [-n, n].
        shifts = torch.Tensor(x.shape[0]).random_(0, 2 * self.shift_factor + 1) - self.shift_factor
        shifts = shifts.numpy().astype(int)

        # Group the samples by shift so that every distinct shift is applied
        # with a single padding call.
        samples_by_shift = {}
        for sample_idx, shift in enumerate(shifts):
            samples_by_shift.setdefault(int(shift), []).append(sample_idx)

        # Work on a copy so the input stays untouched.
        shuffled = x.clone()

        for shift, members in samples_by_shift.items():
            selected = x[members]
            if shift > 0:
                # Shift right: drop the tail, reflect-pad on the left.
                shuffled[members] = F.pad(selected[..., :-shift], (shift, 0), mode='reflect')
            else:
                # Shift left (or not at all): drop the head, reflect-pad on the right.
                shuffled[members] = F.pad(selected[..., -shift:], (0, -shift), mode='reflect')

        assert shuffled.shape == x.shape, "{}, {}".format(shuffled.shape,
                                                       x.shape)
        return shuffled

### Generator Model

In [None]:
class WaveGANGenerator(nn.Module):
    """WaveGAN generator: maps a latent vector to a waveform of ``slice_len`` samples.

    A linear layer expands the latent vector to 16 time steps, which a stack
    of ``Transpose1dLayer`` blocks then upsamples; the final block is followed
    by ``tanh``.

    Args:
        model_size: base channel multiplier (d).
        num_channels: channels of the generated audio.
        verbose: print intermediate shapes during ``forward``.
        upsample: use interpolation + convolution (factor 4, stride 1)
            instead of strided transposed convolutions (stride 4).
        slice_len: output length; one of 16384, 32768, 65536.
        use_batch_norm: apply batch norm in every block except the last.
    """

    def __init__(self, model_size=64, num_channels=1,
                  verbose=False,
                  upsample=True, slice_len=16384, use_batch_norm = False):
        super().__init__()
        assert slice_len in [16384, 32768, 65536] # used to predict longer utterances

        self.model_size = model_size  # d
        self.verbose = verbose
        self.use_batch_norm = use_batch_norm
        self.dim_mul = 16 if slice_len == 16384 else 32

        self.fc1 = nn.Linear(noise_latent_dim, 4*4* model_size * self.dim_mul)
        self.bn1 = nn.BatchNorm1d(num_features= model_size * self.dim_mul)

        if upsample:
            stride, upsample = 1, 4
        else:
            stride = 4

        def up_block(in_ch, out_ch, block_stride=stride, batch_norm=use_batch_norm):
            # All blocks share kernel size 25 and the same upsampling mode.
            return Transpose1dLayer(in_ch, out_ch, 25, block_stride, upsample=upsample, use_batch_norm=batch_norm)

        # Four blocks, each halving the channel count.
        width = self.dim_mul * model_size
        deconv_layers = [
            up_block(width // 2 ** level, width // 2 ** (level + 1))
            for level in range(4)
        ]

        # Output stage; the block producing the audio never uses batch norm.
        if slice_len == 16384:
            deconv_layers.append(up_block(width // 16, num_channels, batch_norm=False))
        elif slice_len == 32768:
            deconv_layers.append(up_block(width // 16, model_size))
            deconv_layers.append(up_block(model_size, num_channels, block_stride=2, batch_norm=False))
        elif slice_len == 65536:
            deconv_layers.append(up_block(width // 16, model_size))
            deconv_layers.append(up_block(model_size, num_channels, batch_norm=False))
        else:
            raise ValueError('slice_len {} value is not supported'.format(slice_len))

        self.deconv_list = nn.ModuleList(deconv_layers)
        for module in self.modules():
            if isinstance(module, (nn.ConvTranspose1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight.data)

    def forward(self, x):
        x = self.fc1(x).view(-1, self.dim_mul* self.model_size, 16)
        if self.use_batch_norm:
            x = self.bn1(x)
        x = F.relu(x)
        if self.verbose:
            print(x.shape)

        *hidden_blocks, output_block = self.deconv_list
        for deconv in hidden_blocks:
            x = F.relu(deconv(x))
            if self.verbose:
                print(x.shape)
        return torch.tanh(output_block(x))

### Discriminator Model

In [None]:
class WaveGANDiscriminator(nn.Module):
    """WaveGAN discriminator: maps a waveform of ``slice_len`` samples to one score.

    Five ``Conv1D`` blocks (kernel 25, stride 4) double the channel count
    from ``model_size`` up to ``16 * model_size``; longer slices get one extra
    block. The flattened features feed a single linear unit.

    Args:
        model_size: base channel multiplier (d).
        ngpus: stored on the instance; not used inside this class.
        num_channels: channels of the input audio (c).
        shift_factor: phase-shuffle factor handed to the blocks (n).
        alpha: negative slope of the leaky ReLUs.
        verbose: print intermediate shapes during ``forward``.
        slice_len: input length; one of 16384, 32768, 65536.
        use_batch_norm: apply batch norm inside the blocks.
    """

    def __init__(self, model_size=64, ngpus=1, num_channels=1, shift_factor=2,
                 alpha=0.2, verbose=False, slice_len=16384, use_batch_norm = False):
        super().__init__()
        assert slice_len in [16384, 32768, 65536] # used to predict longer utterances

        self.model_size = model_size  # d
        self.ngpus = ngpus
        self.use_batch_norm = use_batch_norm
        self.num_channels = num_channels  # c
        self.shift_factor = shift_factor  # n
        self.alpha = alpha
        self.verbose = verbose

        def conv_block(in_ch, out_ch, block_stride, block_shift):
            return Conv1D(in_ch, out_ch, 25, stride=block_stride, padding=11,
                          use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=block_shift)

        # Channel progression: c -> d -> 2d -> 4d -> 8d -> 16d.
        widths = [num_channels] + [factor * model_size for factor in (1, 2, 4, 8, 16)]
        last = len(widths) - 2
        conv_layers = []
        for position, (in_ch, out_ch) in enumerate(zip(widths, widths[1:])):
            # For the shortest slice the fifth block is the final one and
            # gets shift factor 0.
            is_final = position == last and slice_len == 16384
            conv_layers.append(conv_block(in_ch, out_ch, 4, 0 if is_final else shift_factor))

        # Flattened feature size per slice length, plus the extra block's
        # stride for the longer slices.
        self.fc_input_size = 256 * model_size
        if slice_len == 32768:
            conv_layers.append(conv_block(16 * model_size, 32 * model_size, 2, 0))
            self.fc_input_size = 480 * model_size
        elif slice_len == 65536:
            conv_layers.append(conv_block(16 * model_size, 32 * model_size, 4, 0))
            self.fc_input_size = 512 * model_size

        self.conv_layers = nn.ModuleList(conv_layers)

        self.fc1 = nn.Linear(self.fc_input_size, 1)

        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight.data)

    def forward(self, x):
        features = x
        for conv in self.conv_layers:
            features = conv(features)
            if self.verbose:
                print(features.shape)
        features = features.view(-1, self.fc_input_size)
        if self.verbose:
            print(features.shape)

        return self.fc1(features)

### Cyclic Generator with resblocks

In [None]:
class ResidualBlock(nn.Module):
    """Two length-preserving ``Conv1D`` blocks with a skip connection.

    Both convolutions use kernel 21, stride 1 and padding 10 and keep
    ``in_features`` channels, so the result can be added to the input.

    Args:
        in_features: number of input (and output) channels.
        use_batch_norm: apply batch norm inside the blocks.
        alpha: negative slope of the leaky ReLUs.
        shift_factor: phase-shuffle factor handed to the blocks.
    """

    def __init__(self, in_features, use_batch_norm=True, alpha=0.2, shift_factor=2):
        super().__init__()
        self.conv_blocks = nn.ModuleList([
            Conv1D(in_features, in_features, 21, stride=1, padding=10,
                   use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=shift_factor)
            for _ in range(2)
        ])

    def forward(self, x):
        # (An earlier experiment downsampled x with nearest interpolation,
        # scale_factor=0.25, before the convolutions.)
        residual = x
        for conv in self.conv_blocks:
            residual = conv(residual)
        return x + residual

In [None]:
class CyclicWaveGanGenerator(nn.Module):
    """Waveform-to-waveform generator: encoder, residual trunk, decoder.

    The model is assembled from three sequential stages:

    * ``encoder``: two stride-4 ``Conv1D`` layers, plus a third one (stride 2
      for ``slice_len`` 32768, stride 4 for 65536).
    * ``transformation``: 8 ``ResidualBlock`` modules (9 when the third
      encoder layer is present).
    * ``upsampling``: ``Transpose1dLayer`` modules ending in ``num_channels``
      output channels.

    Args:
        model_size: base channel width.
        num_channels: channels of the input and output signal.
        shift_factor, alpha, use_batch_norm: forwarded to the ``Conv1D``
            layers (``use_batch_norm`` also to the non-final decoder layers).
        verbose: print the tensor shape after each stage in ``forward``.
        slice_len: one of 16384, 32768 or 65536.
        upsample: when truthy the decoder layers are built with stride 1 and
            ``upsample=4``; otherwise with stride 4 and the value passed in.
    """

    def __init__(self,model_size=64,  num_channels=1, shift_factor=2,
                 alpha=0.2, verbose=False, slice_len=16384, use_batch_norm = False, upsample=True):
        super().__init__()
        assert slice_len in [16384, 32768, 65536]  # used to predict longer utterances
        self.dim_mul = 16 if slice_len == 16384 else 32
        self.verbose = verbose

        trunk_channels = (self.dim_mul * model_size) // 8
        decoder_channels = (self.dim_mul * model_size) // 16
        conv_options = dict(use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=shift_factor)

        # ---- encoder ----
        # The longer slice lengths get one extra layer; only its stride differs.
        extra_stride = {32768: 2, 65536: 4}.get(slice_len)
        encoder_layers = [
            Conv1D(num_channels, model_size, 25, stride=4, padding=11, **conv_options),
            Conv1D(model_size, 2 * model_size, 25, stride=4,
                   padding=13 if slice_len == 32768 else 11, **conv_options),
        ]
        if extra_stride is not None:
            encoder_layers.append(
                Conv1D(2 * model_size, trunk_channels, 25, stride=extra_stride, padding=11, **conv_options)
            )
        self.encoder = nn.Sequential(*encoder_layers)

        # ---- residual trunk ----
        n_resblocks = 8 if extra_stride is None else 9
        self.transformation = nn.Sequential(
            *[ResidualBlock(trunk_channels) for _ in range(n_resblocks)]
        )

        # ---- decoder ----
        if upsample:
            stride, upsample_arg = 1, 4
        else:
            stride, upsample_arg = 4, upsample

        decoder_layers = [
            Transpose1dLayer(trunk_channels, decoder_channels, 25, stride,
                             upsample=upsample_arg, use_batch_norm=use_batch_norm),
        ]
        if slice_len == 16384:
            decoder_layers.append(
                Transpose1dLayer(decoder_channels, num_channels, 25, stride, upsample=upsample_arg)
            )
        elif slice_len in (32768, 65536):
            # The 32768 variant ends with a stride-2 layer; 65536 reuses `stride`.
            last_stride = 2 if slice_len == 32768 else stride
            decoder_layers.append(
                Transpose1dLayer(decoder_channels, model_size, 25, stride,
                                 upsample=upsample_arg, use_batch_norm=use_batch_norm)
            )
            decoder_layers.append(
                Transpose1dLayer(model_size, num_channels, 25, last_stride, upsample=upsample_arg)
            )
        else:
            raise ValueError('slice_len {} value is not supported'.format(slice_len))
        self.upsampling = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Run ``x`` through encoder, trunk and decoder and return the result."""
        for stage in (self.encoder, self.transformation, self.upsampling):
            x = stage(x)
            if self.verbose:
                print(x.shape)
        return x

### Discriminator Model With Latent Code Eval

In [None]:
class WaveGANDiscriminatorSep(nn.Module):
    """Discriminator that scores an (audio, latent code) pair.

    The audio goes through ``infer_input_audio``; the latent code is projected
    by ``move_z_space`` to ``64 * model_size`` values, viewed as a one-channel
    signal and passed through ``infer_z``. Both results are concatenated along
    the channel axis, processed by ``infer_join``, flattened and mapped to a
    single value by ``fc1``.

    Args:
        model_size: base channel width.
        num_channels: channels of the audio input.
        shift_factor, alpha, use_batch_norm, drop_prob: forwarded to the
            ``Conv1D`` layers as shown in ``__init__``.
        verbose: print tensor shapes during ``forward``.
        slice_len: one of 16384, 32768 or 65536.
    """

    def __init__(self, model_size=64,  num_channels=1, shift_factor=2,
                 alpha=0.2, verbose=False, slice_len=16384, use_batch_norm = False, drop_prob=0.2):
        super().__init__()
        assert slice_len in [16384, 32768, 65536]  # used to predict longer utterances

        self.verbose = verbose
        self.model_size = model_size

        shared = dict(use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=shift_factor)
        # Longer slices get one extra layer in the audio and joint stacks;
        # the two supported variants differ only in that layer's stride.
        extra_stride = {32768: 2, 65536: 4}.get(slice_len)

        # ---- audio branch ----
        audio_widths = (
            (num_channels, model_size),
            (model_size, 2 * model_size),
            (2 * model_size, 4 * model_size),
        )
        audio_layers = [
            Conv1D(c_in, c_out, 25, stride=4, padding=11, drop_prob=drop_prob, **shared)
            for c_in, c_out in audio_widths
        ]
        if extra_stride is not None:
            audio_layers.append(
                Conv1D(4 * model_size, 4 * model_size, 25, stride=extra_stride, padding=12, **shared)
            )
        self.infer_input_audio = nn.ModuleList(audio_layers)

        # ---- latent branch ----
        self.move_z_space = nn.Linear(noise_latent_dim, 64 * model_size)
        self.infer_z = nn.ModuleList([
            Conv1D(1, model_size // 8, 25, stride=4, padding=11,
                   use_batch_norm=use_batch_norm, drop_prob=drop_prob),
            Conv1D(model_size // 8, model_size // 4, 25, stride=4, padding=11,
                   use_batch_norm=use_batch_norm, drop_prob=drop_prob),
        ])

        # ---- joint branch ----
        joint_layers = [
            Conv1D((4 * model_size) + (model_size // 4), 8 * model_size, 25, stride=4, padding=11,
                   drop_prob=drop_prob, **shared),
            Conv1D(8 * model_size, 16 * model_size, 25, stride=4, padding=11,
                   use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=0,
                   drop_prob=0 if slice_len == 16384 else drop_prob),
        ]
        if extra_stride is not None:
            joint_layers.append(
                Conv1D(16 * model_size, 32 * model_size, 25, stride=extra_stride, padding=11, **shared)
            )
        self.infer_join = nn.ModuleList(joint_layers)

        # Width of the flattened joint features expected by fc1.
        fc_multiplier = {32768: 7, 65536: 4}.get(slice_len, 8)
        self.fc_input_size = fc_multiplier * 32 * model_size
        self.fc1 = nn.Linear(self.fc_input_size, 1)

        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight.data)

    def _run_stack(self, layers, tensor):
        """Apply ``layers`` in order, printing each output shape when verbose."""
        for layer in layers:
            tensor = layer(tensor)
            if self.verbose:
                print(tensor.shape)
        return tensor

    def forward(self, x, z):
        """Return the ``fc1`` output for audio ``x`` paired with latent ``z``."""
        audio = self._run_stack(self.infer_input_audio, x)
        if self.verbose:
            # Final audio-branch shape is printed once more after the stack.
            print(audio.shape)

        latent = self.move_z_space(z).view(-1, 1, self.model_size * 64)
        if self.verbose:
            print(latent.shape)
        latent = self._run_stack(self.infer_z, latent)

        # join z and x along the channel axis
        joint = torch.cat([audio, latent], dim=1)
        if self.verbose:
            print(joint.shape)
        joint = self._run_stack(self.infer_join, joint)

        flattened = joint.view(-1, self.fc_input_size)
        if self.verbose:
            print(flattened.shape)
        return self.fc1(flattened)

### Latent Space Encoder

In [None]:
class WaveGanEncoder(nn.Module):
    """Encode a waveform into a latent vector of ``noise_latent_dim`` values.

    Four stride-4 ``Conv1D`` layers (plus a fifth for ``slice_len`` 32768 or
    65536) are followed by adaptive average pooling to length 16. The pooled
    features are flattened and fed to two linear heads producing ``mu`` and
    ``logvar``; ``forward`` returns ``mu + exp(logvar / 2) * noise`` where the
    noise comes from ``sample_noise``.

    Args:
        model_size: base channel width.
        num_channels: channels of the input signal.
        shift_factor, alpha, use_batch_norm: forwarded to every ``Conv1D``.
        verbose: print tensor shapes during ``forward``.
        slice_len: 32768 and 65536 add a fifth layer; any other value uses
            the four-layer stack.
    """

    def __init__(self,model_size=64,  num_channels=1, shift_factor=2,
                 alpha=0.2, verbose=False, slice_len=16384, use_batch_norm = False):
        super().__init__()
        self.verbose = verbose

        shared = dict(use_batch_norm=use_batch_norm, alpha=alpha, shift_factor=shift_factor)
        widths = [num_channels, model_size, 2 * model_size, 4 * model_size, 8 * model_size]
        layers = [
            Conv1D(c_in, c_out, 25, stride=4, padding=11, **shared)
            for c_in, c_out in zip(widths, widths[1:])
        ]
        out_channels = 8 * model_size

        # The longer slice lengths add one layer that doubles the channel count.
        extra_stride = {32768: 2, 65536: 4}.get(slice_len)
        if extra_stride is not None:
            layers.append(
                Conv1D(8 * model_size, 16 * model_size, 25, stride=extra_stride, padding=11, **shared)
            )
            out_channels = 16 * model_size

        # Pooling always yields 16 time steps, so the flattened width is channels * 16.
        self.fc_input_size = out_channels * 16
        self.conv_layers = nn.ModuleList(layers)

        self.adaptive_pooling = nn.AdaptiveAvgPool1d(16)  # needs to be checked later

        self.mu_linear = nn.Linear(self.fc_input_size, noise_latent_dim)
        self.logvar_linear = nn.Linear(self.fc_input_size, noise_latent_dim)
        for module in self.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight.data)

    def reparametrize(self, mu, logvar):
        """Return ``mu + std * noise`` with ``std = exp(logvar / 2)``."""
        std = (logvar * 0.5).exp()
        noise = sample_noise(std.size(0))
        return mu + std * noise

    def bottleneck(self, h):
        """Map features ``h`` to ``(z, mu, logvar)``."""
        mu = self.mu_linear(h)
        logvar = self.logvar_linear(h)
        return self.reparametrize(mu, logvar), mu, logvar

    def forward(self, x, inference=True):
        """Encode ``x`` and return the sampled latent ``z``.

        ``inference`` is accepted but not used.
        """
        features = x
        for layer in self.conv_layers:
            features = layer(features)
            if self.verbose:
                print(features.shape)

        pooled = self.adaptive_pooling(features)
        if self.verbose:
            print(pooled.shape)

        flattened = pooled.view(-1, self.fc_input_size)
        if self.verbose:
            print(flattened.shape)

        return self.bottleneck(flattened)[0]

### Model Init and Testing

In [None]:
# Smoke test for WaveGANGenerator: build it on `device`, run one forward pass
# on the output of sample_noise(10) and print the shape of what comes out.
# `generated_audio_test` is reused by the discriminator/encoder cells below.
# NOTE(review): sample_noise presumably returns 10 latent vectors — confirm.
z = sample_noise(10)
generator_model = WaveGANGenerator(verbose=True, upsample=True,  use_batch_norm=False ,slice_len = window_length).to(device)
generated_audio_test = generator_model(z)
print(generated_audio_test.shape)

In [None]:
# Turn the verbose flag off before asking torchsummary for the layer table,
# using an input size of (1, noise_latent_dim).
generator_model.verbose= False
summary(generator_model, input_size=(1,noise_latent_dim))

In [None]:
# Smoke test for WaveGANDiscriminator (built with batch norm enabled): score
# `generated_audio_test` and print the shape of the result.
discriminator_model = WaveGANDiscriminator(verbose=True, use_batch_norm=True  ,slice_len = window_length).to(device)
discriminating_output = discriminator_model(generated_audio_test)
print(discriminating_output.shape)

In [None]:
# Turn the verbose flag off, then print the torchsummary table for an input
# of size (1, window_length).
discriminator_model.verbose= False
summary(discriminator_model, input_size=(1,window_length))

In [None]:
# Smoke test for WaveGANDiscriminatorSep, whose forward takes (audio, z):
# pair `generated_audio_test` with sample_noise(10) and print the output shape.
# This rebinds `discriminator_model` from the previous cells.
discriminator_model = WaveGANDiscriminatorSep(verbose=True  ,slice_len = window_length).to(device)
discriminating_output = discriminator_model(generated_audio_test, sample_noise(10))
print(discriminating_output.shape)

In [None]:
# Turn the verbose flag off, then print the torchsummary table; the model has
# two inputs, so two input sizes are given (audio first, latent code second).
discriminator_model.verbose=False
summary(discriminator_model, input_size=[(1,window_length),(1,noise_latent_dim)])

In [None]:
# Smoke test for WaveGanEncoder: encode `generated_audio_test` and print the
# shape of the returned latent tensor.
encoder_model = WaveGanEncoder(verbose=True  ,slice_len = window_length).to(device)
latent_encoded = encoder_model(generated_audio_test)
print(latent_encoded.shape)

In [None]:
# Turn the verbose flag off, then print the torchsummary table for an input
# of size (1, window_length).
encoder_model.verbose=False
summary(encoder_model, input_size=(1,window_length))

In [None]:
# Smoke test for CyclicWaveGanGenerator: feed a uniform-random tensor of shape
# (10, 1, window_length) and print the output shape.
cyclic_generator = CyclicWaveGanGenerator(slice_len=window_length, verbose = True).to(device)
generated = cyclic_generator(Variable(torch.rand(10,1,window_length)).to(device))
print(generated.shape)

In [None]:
# Turn the verbose flag off, then print the torchsummary table for an input
# of size (1, window_length).
cyclic_generator.verbose= False
summary(cyclic_generator, input_size=(1,window_length))

### Cyclic Gan

In [None]:
# Pull one batch out of the loader to check that iteration works.
# NOTE(review): despite the names `val_set` / `val_data`, this iterates
# `train_loader`, not a validation loader — confirm that is intended.
val_set = iter(train_loader)
print(val_set)
val_data = next(val_set)

In [None]:
class CycleGan(object):
    """Trainer for one separation generator and two waveform discriminators.

    ``discriminator_1`` is trained on single-source signals and
    ``discriminator_2`` on mixtures. Both use a relativistic average
    least-squares objective; the generator cost adds an L1 identity term and
    an L1 cycle-reconstruction term.

    Hyper-parameters and switches (``n_iterations``, ``lr_g``, ``lr_d``,
    ``batch_size``, ``take_backup``, ``decay_lr`` ...) are read from
    module-level globals.
    """

    def __init__(self, train_loader, val_loader, validate=True):
        """Build the networks, their Adam optimizers and empty cost histories.

        Args:
            train_loader: iterable of batches; item 0 of a batch is used as
                the single-source signal and item 1 as the mixture.
            val_loader: iterable with the same batch layout, used for the
                fixed preview batch and for validation.
            validate: when true, validation metrics are recorded every
                ``store_cost_every`` iterations.
        """
        # train cyclic to just disable the model from single source to mixed as an experiment
        self.validate = validate

        self.val_g_cost = []
        self.train_g_cost = []
        self.cyclic_loss = []
        self.valid_reconstruction = []

        self.discriminator_loss = []

        self.generator = CyclicWaveGanGenerator(slice_len = window_length, model_size = model_capacity_size).to(device)
        self.generator.apply(weights_init)

        self.discriminator_1 = WaveGANDiscriminator(slice_len = window_length, model_size = model_capacity_size).to(device)
        self.discriminator_1.apply(weights_init)

        self.discriminator_2 = WaveGANDiscriminator(slice_len = window_length, model_size = model_capacity_size).to(device)
        self.discriminator_2.apply(weights_init)

        self.optimizer_d_1 = optim.Adam(self.discriminator_1.parameters(), lr=lr_d, betas=(beta1,  beta2))
        self.optimizer_d_2 = optim.Adam(self.discriminator_2.parameters(), lr=lr_d, betas=(beta1,  beta2))
        self.optimizer_g = optim.Adam(self.generator.parameters(), lr=lr_g, betas=(beta1,  beta2))

        self.train_loader = train_loader
        self.val_loader = val_loader

    def apply_zero_grad(self):
        """Clear the accumulated gradients of all three networks."""
        self.discriminator_1.zero_grad()
        self.discriminator_2.zero_grad()
        self.generator.zero_grad()

    def enable_gen_disable_disc(self):
        """Set gradient status on for the generator and off for both discriminators."""
        gradients_status(self.generator, True)
        gradients_status(self.discriminator_1, False)
        gradients_status(self.discriminator_2, False)

    def disable_all(self):
        """Set gradient status off for all three networks."""
        gradients_status(self.generator, False)
        gradients_status(self.discriminator_1, False)
        gradients_status(self.discriminator_2, False)

    @staticmethod
    def _relativistic_loss(favoured, other, target):
        """Relativistic average least-squares loss.

        Penalises ``favoured`` scores for not exceeding the mean ``other``
        score by ``target`` and ``other`` scores for not falling below the
        mean ``favoured`` score by ``target``. The discriminator step passes
        ``(real, fake)``; the generator step passes ``(fake, real)``.
        """
        favoured_term = torch.mean((favoured - torch.mean(other) - target) ** 2)
        other_term = torch.mean((other - torch.mean(favoured) + target) ** 2)
        return (favoured_term + other_term) / 2

    def _load_checkpoint(self, path):
        """Restore networks, optimizers and histories; return the next iteration index."""
        if cuda:
            checkpoint = torch.load(path)
        else:
            checkpoint = torch.load(path, map_location='cpu')
        self.generator.load_state_dict(checkpoint['generator'])
        self.discriminator_1.load_state_dict(checkpoint['discriminator_1'])
        self.discriminator_2.load_state_dict(checkpoint['discriminator_2'])
        self.optimizer_d_1.load_state_dict(checkpoint['optimizer_d_1'])
        self.optimizer_d_2.load_state_dict(checkpoint['optimizer_d_2'])
        self.optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        self.val_g_cost = checkpoint['val_g_cost']
        self.train_g_cost = checkpoint['train_g_cost']
        self.cyclic_loss = checkpoint['cyclic_loss']
        self.discriminator_loss = checkpoint['discriminator_loss']
        # Checkpoints written before this key existed simply keep the current list.
        self.valid_reconstruction = checkpoint.get('valid_reconstruction', self.valid_reconstruction)
        return checkpoint['n_iterations'] + 1

    def _save_checkpoint(self, path, iter_indx):
        """Write networks, optimizers, histories and ``iter_indx`` to ``path``."""
        saving_dict = {
            'generator': self.generator.state_dict(),
            'discriminator_1': self.discriminator_1.state_dict(),
            'optimizer_d_1': self.optimizer_d_1.state_dict(),
            'discriminator_2': self.discriminator_2.state_dict(),
            'optimizer_d_2': self.optimizer_d_2.state_dict(),
            'optimizer_g': self.optimizer_g.state_dict(),
            'val_g_cost': self.val_g_cost,
            'train_g_cost': self.train_g_cost,
            'cyclic_loss': self.cyclic_loss,
            'valid_reconstruction': self.valid_reconstruction,
            'discriminator_loss': self.discriminator_loss,
            'n_iterations': iter_indx
        }
        torch.save(saving_dict, path)

    def train(self):
        """Run the training loop up to the global ``n_iterations``.

        Each iteration updates discriminator 1, discriminator 2 and then the
        generator. Depending on the module-level settings it also records
        costs, saves preview samples for a fixed validation batch, decays the
        learning rates and writes a checkpoint. If ``take_backup`` is set and
        a checkpoint file for ``model_prefix`` exists, training resumes from
        the iteration after the one stored in it.
        """
        progress_bar =  tqdm(total=n_iterations//progress_bar_step_iter_size)
        val_set = iter(self.val_loader)
        val_data = next(val_set)
        # One fixed validation batch is kept for the periodic preview samples.
        fixed_mixed_signal = val_data[1].to(device)
        fixed_single_signal = val_data[0].to(device)
        save_samples(fixed_mixed_signal.detach().cpu().numpy(), 'fixed_mixed' )
        save_samples(fixed_single_signal.detach().cpu().numpy(), 'fixed_single' )

        gan_model_name = 'gan_cyclic_single_2disc_{}.tar'.format(model_prefix)

        first_iter = 0
        if take_backup and os.path.isfile(gan_model_name):
            first_iter = self._load_checkpoint(gan_model_name)
            # Advance the bar to where the previous run stopped.
            for _ in range(0, first_iter, progress_bar_step_iter_size):
                progress_bar.update()
            # Takes effect only if the loop below does not run; otherwise
            # every iteration switches the generator back to train mode.
            self.generator.eval()

        criterion_GAN = nn.MSELoss()
        criterion_cycle = nn.L1Loss()
        criterion_identity = nn.L1Loss()

        # Margin used by the relativistic losses and target for the validation MSE.
        target_real = torch.ones(batch_size, 1).to(device)

        train_set = iter(self.train_loader)
        for iter_indx in range(first_iter, n_iterations):
            self.generator.train()
            self.discriminator_1.train()
            self.discriminator_2.train()
            try:
                data = next(train_set)
            except StopIteration:
                # Loader exhausted: start another pass over the data.
                train_set = iter(self.train_loader)
                data = next(train_set)

            # in case of unpaired data
            single_signal = data[0].to(device)
            mixed_signal = data[1].to(device)

            #############################
            # Training First Discriminator
            #############################
            self.apply_zero_grad()
            self.disable_all()
            gradients_status(self.discriminator_1, True)
            # NOTE(review): this forward pass runs right after disable_all().
            # If gradients_status toggles requires_grad on the parameters, the
            # identity and adversarial terms of the generator step below carry
            # no gradient to the generator (only the cycle term would) —
            # confirm against gradients_status.
            generated_single_signal = self.generator(single_signal)
            rest_of_signal = mixed_signal - generated_single_signal
            new_mixed_signal = rest_of_signal + single_signal

            is_single_signal_r = self.discriminator_1(single_signal)
            is_single_signal_f = self.discriminator_1(generated_single_signal.detach())
            d_loss_1 = self._relativistic_loss(is_single_signal_r, is_single_signal_f, target_real)
            d_loss_1.backward()
            self.optimizer_d_1.step()

            #############################
            # Training Second Discriminator
            #############################
            self.apply_zero_grad()
            self.disable_all()
            gradients_status(self.discriminator_2, True)
            is_mixed_signal_r = self.discriminator_2(mixed_signal)
            is_mixed_signal_f = self.discriminator_2(new_mixed_signal.detach())
            d_loss_2 = self._relativistic_loss(is_mixed_signal_r, is_mixed_signal_f, target_real)
            d_loss_2.backward()
            self.optimizer_d_2.step()

            #############################
            # Training  generator
            #############################
            self.apply_zero_grad()
            self.enable_gen_disable_disc()

            # Identity loss without it the model would make changes to input even without any need
            identity_loss_1 = criterion_identity(generated_single_signal, single_signal)

            # Adversarial terms: same loss with the roles of real and fake swapped.
            is_single_signal_r = self.discriminator_1(single_signal)
            is_single_signal_f = self.discriminator_1(generated_single_signal)
            gan_loss_1 = self._relativistic_loss(is_single_signal_f, is_single_signal_r, target_real)

            is_mixed_signal_r = self.discriminator_2(mixed_signal)
            is_mixed_signal_f = self.discriminator_2(new_mixed_signal)
            gan_loss_2 = self._relativistic_loss(is_mixed_signal_f, is_mixed_signal_r, target_real)

            # Cycle term: separating the re-mixed signal should give back the single source.
            reconstructed_single_signal = self.generator(new_mixed_signal)
            cycle_loss_1 = criterion_cycle(reconstructed_single_signal, single_signal)

            # Total Loss
            g_cost  =  identity_loss_1*0.5 + (gan_loss_1 + gan_loss_2)  + 10*cycle_loss_1
            g_cost.backward()
            self.optimizer_g.step()

            if self.validate and iter_indx%store_cost_every==0:
                self.discriminator_loss.append(d_loss_1.item())
                self.train_g_cost.append(g_cost.item())
                # validating
                self.disable_all()
                with torch.no_grad():
                    try:
                        val_data = next(val_set)
                    except StopIteration:
                        val_set = iter(self.val_loader)
                        val_data = next(val_set)
                    val_single = val_data[0].to(device)
                    val_mixed = val_data[1].to(device)
                    # NOTE(review): discriminator_2 is trained on mixtures but
                    # is given val_single here — confirm this is intended.
                    val_cost= criterion_GAN(self.discriminator_1(val_single), target_real) + criterion_GAN(self.discriminator_2(val_single), target_real)
                    self.val_g_cost.append(val_cost.item())
                    reconstructed_music = self.generator(val_mixed)
                    self.valid_reconstruction.append(F.mse_loss(reconstructed_music, val_single, reduction='sum').item())

            if iter_indx%store_cost_every==0 :
                # With validate=False nothing has been recorded yet, so report NaN
                # instead of indexing an empty list.
                reconstruction = self.valid_reconstruction[-1] if self.valid_reconstruction else float('nan')
                progress_updates = {'Reconstruction': str(reconstruction),'Loss_D1': str(d_loss_1.item()), 'Loss_g':str(g_cost.item())}
                progress_bar.set_postfix(progress_updates)

            if iter_indx%progress_bar_step_iter_size==0:
                progress_bar.update()

            # lr decay
            if decay_lr:
                decay = max(0.0,1.0-(iter_indx*1.0/n_iterations))
                # update the learning rate of every optimizer this class owns
                update_optimizer_lr(self.optimizer_d_1, lr_d, decay)
                update_optimizer_lr(self.optimizer_d_2, lr_d, decay)
                update_optimizer_lr(self.optimizer_g, lr_g, decay)

            if (iter_indx%save_samples_every==0):
                with torch.no_grad():
                    fake = self.generator(fixed_mixed_signal).detach().cpu().numpy()
                save_samples(fake, iter_indx , prefix='predictions')

            if take_backup and iter_indx%backup_every_n_iters==0:
                self._save_checkpoint(gan_model_name, iter_indx)

In [None]:
# Suppress all Python warnings from this point on in the session.
import warnings
warnings.filterwarnings("ignore")

## Train

In [None]:
# Build a fresh trainer, then overwrite its networks, optimizers and cost
# histories with the contents of a previously saved checkpoint.
model = CycleGan(train_loader, val_loader)
# NOTE(review): the path contains no '{}' placeholder, so .format(model_prefix)
# returns it unchanged.
gan_model_name = '/kaggle/input/test-bansuri-checkpoint/gan_cyclic_single_2disc_exp_musdb_1_wide_unpaired_ralsgan_4.tar'.format(model_prefix)
# map_location='cpu' loads the stored tensors onto the CPU.
checkpoint = torch.load(gan_model_name, map_location='cpu')


model.generator.load_state_dict(checkpoint['generator'])
model.discriminator_1.load_state_dict(checkpoint['discriminator_1'])
model.discriminator_2.load_state_dict(checkpoint['discriminator_2'])
model.optimizer_d_1.load_state_dict(checkpoint['optimizer_d_1'])
model.optimizer_d_2.load_state_dict(checkpoint['optimizer_d_2'])
model.optimizer_g.load_state_dict(checkpoint['optimizer_g'])


model.val_g_cost = checkpoint['val_g_cost']
model.train_g_cost = checkpoint['train_g_cost']
model.cyclic_loss = checkpoint['cyclic_loss']
model.discriminator_loss = checkpoint['discriminator_loss']
# NOTE(review): this rebinds the global n_iterations, which CycleGan.train()
# uses as its iteration budget, to the iteration index stored in the
# checkpoint. Confirm that is intended rather than resuming from that index.
n_iterations = checkpoint['n_iterations']

# Later cells refer to the trainer as `gan_model`.
gan_model = model

In [None]:
# Start (or continue) training with the trainer prepared in the cell above.
# Uncomment the next line to train from a freshly initialised model instead.
# gan_model = CycleGan(train_loader,val_loader)
gan_model.train()

# **Evaluation**

In [None]:
!pip install mir_eval
!pip install museval

In [None]:
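## Recorded results, vocals on MUSDB (1000000 iterations): SIR 11.670795746936527, SAR 11.670795746936527, SDR 6.5125854585956136
## Recorded results, vocals on Nepali data (100000 iterations): SIR 11.603162618823603, SAR 11.603162618823603, SDR 7.324119863550323
## NOTE: SAR is identical to SIR in both records, which suggests the SAR figure was averaged from the SIR values; treat the recorded SAR as unverified.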
import mir_eval
import museval
import gc
def _any_source_silent(sources):
    """Returns true if the parameter sources has any silent first dimensions"""
    return np.any(np.all(np.sum(
        sources, axis=tuple(range(2, sources.ndim))) == 0, axis=1))

# Evaluate the trained generator on every item of the validation dataset and
# print the median, over items, of the per-item mean framewise SIR/SAR/SDR.
test_dataset = LMDBWavLoader('musdb_valid',True)   ## LMDBWavLoader('/kaggle/input/musdb-valid/musdb_valid/musdb_valid', True)
test_data_len = len(test_dataset)
# Signals are truncated to a multiple of this many samples before scoring.
n_items_per_eval = 3
out_sdr = []
out_sir = []
out_sar = []
with torch.no_grad():
    for i in tqdm(range(test_data_len)):
        val_data = test_dataset[i]
        # Item 1 is the mixture: cut it into windows and run them as one batch.
        windows = np.stack(list(audio_generator(val_data[1])), axis=0)
        mixed_signal = torch.squeeze(torch.from_numpy(windows).float()).to(device)
        mixed_signal = torch.unsqueeze(mixed_signal, dim=1)
        generated = gan_model.generator(mixed_signal)
        # Item 0 is the reference single source.
        single_source = val_data[0]

        data_len = (single_source.shape[-1]//n_items_per_eval) * n_items_per_eval
        # Flatten the generated windows back into one signal of at most data_len samples.
        reconstructed = torch.squeeze(generated).detach().cpu().numpy().reshape(1,-1)[:,:data_len].reshape(-1)
        single_source = single_source[:data_len].reshape(-1)
        mixed_source = val_data[1][:data_len].reshape(-1)
        # The second source is whatever remains after removing the first one.
        clean_inference = mixed_source - single_source
        predicted_inference = mixed_source - reconstructed

        # Drop the device tensors before the metric computation.
        del generated
        del mixed_signal
        gc.collect()

        sdr_b, sir_b, sar_b, _ = mir_eval.separation.bss_eval_sources_framewise(
            np.array([single_source, clean_inference]),
            np.array([reconstructed, predicted_inference]))
        # Per-item score: mean over all non-NaN entries of each metric.
        out_sdr.append(np.mean(sdr_b[~np.isnan(sdr_b)]))
        out_sir.append(np.mean(sir_b[~np.isnan(sir_b)]))
        # SAR is averaged from the SAR values (it was previously taken from SIR).
        out_sar.append(np.mean(sar_b[~np.isnan(sar_b)]))
    sdr  = np.median(out_sdr)
    sir = np.median(out_sir)
    sar = np.median(out_sar)
    print('SIR {} SAR {} SDR {}'.format(str(sir),str(sar),str(sdr)))

In [None]:
from IPython.display import Audio

def preprocess(x):
    """Return a copy of ``x`` with one sample zeroed at each window seam.

    For every multiple ``k`` of the global ``window_length`` that is smaller
    than the length of the last axis, element ``[0, 0, k - 1]`` of the copy is
    set to 0. The input array itself is left untouched.
    """
    cleaned = np.copy(x)
    total_len = cleaned.shape[-1]
    # Index of the last sample of each window that is followed by more audio.
    for seam in range(window_length - 1, total_len - 1, window_length):
        cleaned[0, 0, seam] = 0
    return cleaned
    
def predict_long_wav(wav_file_path, out_folder_name):
    """Run the trained generator over a whole wav file and save the result.

    The file is loaded with ``load_wav``, cut into windows by
    ``audio_generator`` and passed through ``gan_model.generator`` as one
    batch. The generated windows are joined into a single (1, 1, N) array,
    cleaned by ``preprocess`` and written with ``save_samples``.

    Args:
        wav_file_path: path of the input wav file.
        out_folder_name: forwarded to ``save_samples`` as its second argument.

    Returns:
        The processed array of shape (1, 1, N).
    """
    print(wav_file_path)
    wav_data = load_wav(wav_file_path)
    windows = np.stack(list(audio_generator(wav_data)))
    batch = torch.from_numpy(windows).float().to(device)
    batch = torch.unsqueeze(batch, dim=1)
    # Inference only: without no_grad the autograd graph for the whole file
    # would be kept alive until the output is detached.
    with torch.no_grad():
        reconstructed = gan_model.generator(batch)
    x = reconstructed.view(1,1,-1).detach().cpu().numpy()
    x = preprocess(x)
    save_samples(x,  out_folder_name )
    return x
# Separate one example mixture; the result is kept in `x` for the playback
# cell below. The input path is hard-coded to a Kaggle dataset location.
x = predict_long_wav('/kaggle/input/nepali-music-source-seperation/nepali_music_source_seperation/train/mixture/01_mixture.wav', 'out')

In [None]:
# Play the predicted waveform (x has shape (1, 1, N), so x[0][0] is the signal).
# NOTE(review): the 16000 Hz rate is hard-coded — confirm it matches the
# sample rate the data was loaded with.
Audio(x[0][0], rate=16000)


In [None]:
# ## for testing the output of the model - storing the testing results
# for file in range(120, 150, 1):
#     predict_long_wav(f'/kaggle/input/to-delete-renamed/Newari_song/Newari_song/mixture/mixture_{file}.wav', f'out{file}')

In [None]:
# Release unoccupied cached GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()