In [None]:
import tensorflow as tf
import tensorflow_io as tfio
import os
import numpy as np
import matplotlib.pyplot as plt
import time
import cv2
from tensorflow.keras import layers
from keras.models import Model, Sequential
from tensorflow.keras.applications import VGG19
from keras.layers import Dense, Conv2D, Flatten, BatchNormalization, LeakyReLU
from keras.layers import Conv2DTranspose, Dropout, ReLU, Input, Concatenate, ZeroPadding2D
from keras.optimizers import Adam
from keras.utils import plot_model
from IPython.display import Audio
import librosa
import random
import matplotlib.pyplot as plt

In [None]:
import tensorflow as tf

# Enable on-demand GPU memory allocation so TensorFlow does not reserve all GPU memory up front.
physical_devices = tf.config.list_physical_devices('GPU')
if not physical_devices:
    print("No GPU devices found. Make sure your GPU is properly installed and configured.")
else:
    try:
        # Apply to every visible GPU, not just the first one.
        for device in physical_devices:
            tf.config.experimental.set_memory_growth(device, True)
        print("GPU configured successfully.")
    except RuntimeError as err:
        # set_memory_growth raises RuntimeError once the runtime is initialised
        # (e.g. when this cell is re-run after a model has been built).
        print(err)

In [None]:
import tensorflow as tf

# Enable on-demand GPU memory growth for the first visible GPU (if any).
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as e:
        # Raised when the GPU runtime has already been initialised.
        print(e)
    else:
        print("GPU memory growth set to True")

In [None]:
import tensorflow as tf

# Report how many GPUs TensorFlow can see.
num_gpus = len(tf.config.experimental.list_physical_devices('GPU'))

if not num_gpus:
    print("No GPU available. TensorFlow is using CPU.")
else:
    print("Num GPUs Available: ", num_gpus)
    print("TensorFlow is using GPU.")

In [None]:
def load(audio_file):
    """Read a WAV file and return its samples as a 1-D float32 tensor.

    Args:
        audio_file: path (Python string or scalar string tensor) of the WAV file.

    Returns:
        1-D float32 tensor of samples. ``tf.audio.decode_wav`` yields shape
        (samples, channels); the trailing channel axis is squeezed away, so this
        only works for single-channel files. The sample rate is discarded.
    """
    raw_bytes = tf.io.read_file(audio_file)
    decoded, _sample_rate = tf.audio.decode_wav(raw_bytes)
    mono = tf.squeeze(decoded, axis=[-1])
    return tf.cast(mono, tf.float32)


# def load(audio_file):
#     audio = tfio.audio.AudioIOTensor(audio_file).to_tensor()
#     audio = tf.squeeze(audio, axis=[-1])
#     audio = tf.cast(audio,tf.float32)/32768.0
#     return audio

In [None]:
# Load one validation clip and listen to the whole thing.
audio = load('/kaggle/input/music-for-gan/music/val/music_144.wav')
print(audio.shape, type(audio))
# `load` already returns a 1-D tensor, so no reshape (and no hard-coded clip length) is needed.
Audio(audio.numpy(), rate=44100)

In [None]:
# Treat the first half of the clip as the isolated bansuri track.
split_index = audio.shape[0] // 2  # 1323000 for this 2646000-sample clip
bansuri_audio = audio[:split_index]
print(bansuri_audio.shape)
Audio(bansuri_audio.numpy(), rate=44100)

In [None]:
# Treat the second half of the clip as the mixture to be separated.
split_index = audio.shape[0] // 2  # 1323000 for this 2646000-sample clip
mixture_audio = audio[split_index:]
Audio(mixture_audio.numpy(), rate=44100)

In [None]:
# Log spectrogram of the bansuri segment: nfft=window=1022 gives 1022 // 2 + 1 = 512 frequency bins, hop 256 samples.
print(bansuri_audio.shape, type(bansuri_audio))
bansuri_spectogram = tfio.audio.spectrogram(bansuri_audio, nfft=1022, window=1022, stride=256)
print(bansuri_spectogram.shape)
plt.figure()
# +1e-6 keeps exactly-zero bins finite (log(0) = -inf would be masked out by imshow and left blank);
# origin='lower' puts low frequencies at the bottom, aspect='auto' stretches the very wide time axis.
plt.imshow(tf.math.log(tf.transpose(bansuri_spectogram) + 1e-6).numpy(), origin='lower', aspect='auto');

In [None]:
# Log spectrogram of the mixture segment (same STFT settings as the bansuri plot).
# Print the mixture's shape (previously printed the full clip's `audio.shape` by mistake).
print(mixture_audio.shape, type(mixture_audio))
mixture_spectogram = tfio.audio.spectrogram(mixture_audio, nfft=1022, window=1022, stride=256)
print(mixture_spectogram.shape)
plt.figure()
# +1e-6 keeps exactly-zero bins finite; origin/aspect make the wide spectrogram readable.
plt.imshow(tf.math.log(tf.transpose(mixture_spectogram) + 1e-6).numpy(), origin='lower', aspect='auto');

In [None]:
# Preview one random 256-frame x 512-bin crop, the same patch size used for training.
seed1 = random.randint(0, 2500)
seed2 = random.randint(0, 2500)
# Crop into a new variable so re-running this cell does not overwrite the full spectrogram.
bansuri_crop = tf.image.stateless_random_crop(bansuri_spectogram, (256, 512), (seed1, seed2))
plt.figure()
plt.imshow(tf.math.log(tf.transpose(bansuri_crop) + 1e-6).numpy(), origin='lower', aspect='auto');


In [None]:
def load_train_audio(bansuri_audio_path):
    """Turn one training clip into an aligned (mixture, bansuri) pair of spectrogram crops.

    The clip is split into two consecutive 88200-sample segments (2 s at the 44.1 kHz
    rate used elsewhere in this notebook): the first is used as the bansuri target and
    the second as the mixture input. Both are turned into spectrograms
    (nfft=window=1022 -> 512 frequency bins, hop 256) and cropped to 256 frames x 512
    bins at the SAME random offset, so input and target stay aligned.

    Args:
        bansuri_audio_path: path (scalar string tensor from ``Dataset.list_files``) of the clip.

    Returns:
        (mixture, bansuri), each a float32 tensor of shape (256, 512, 1).
    """
    clip = load(bansuri_audio_path)
    bansuri = clip[:88200]
    mixture = clip[88200:2 * 88200]
    bansuri = tfio.audio.spectrogram(bansuri, nfft=1022, window=1022, stride=256)
    mixture = tfio.audio.spectrogram(mixture, nfft=1022, window=1022, stride=256)
    # Draw the crop seed with a TF op rather than Python's `random`: Dataset.map traces this
    # function once, so a Python-level random value would be frozen into the graph and every
    # example would be cropped at the same offset in every epoch. One shared seed keeps the
    # two crops aligned. maxval is exclusive, matching randint(0, 2500)'s inclusive range.
    seed = tf.random.uniform([2], minval=0, maxval=2501, dtype=tf.int32)
    bansuri = tf.image.stateless_random_crop(bansuri, (256, 512), seed)
    mixture = tf.image.stateless_random_crop(mixture, (256, 512), seed)
    bansuri = tf.reshape(bansuri, (256, 512, 1))
    mixture = tf.reshape(mixture, (256, 512, 1))
    return mixture, bansuri

In [None]:
# Examples per batch, shared by the training and validation pipelines.
BATCH_SIZE: int = 32

In [None]:
# Training input pipeline: list clips (file order is reshuffled each epoch by list_files)
# -> decode + random-crop in parallel -> shuffle -> batch -> prefetch.
train_dataset = tf.data.Dataset.list_files('/kaggle/input/music-gan2/music2sec/kaggle/working/music/train/*.wav')
train_dataset = train_dataset.map(load_train_audio, num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.shuffle(10).batch(BATCH_SIZE)
# Overlap preparing the next batch with the current training step.
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)
train_dataset

In [None]:
# Validation input pipeline: fixed file order (no need to shuffle for evaluation),
# parallel decode + crop, batch, prefetch.
validation_dataset = tf.data.Dataset.list_files(
    '/kaggle/input/music-gan2/music2sec/kaggle/working/music/val/*.wav', shuffle=False
)
validation_dataset = validation_dataset.map(load_train_audio, num_parallel_calls=tf.data.AUTOTUNE)
validation_dataset = validation_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
validation_dataset

In [None]:
from keras import Input, Model
from keras.layers import Conv2D, Dropout, BatchNormalization, LeakyReLU, Conv2DTranspose, Activation, Concatenate, Multiply
from tensorflow.keras.utils import plot_model

def unet_encoder_block(x, filters):
    """Encoder step: 5x5 stride-2 Conv2D -> BatchNorm -> LeakyReLU(0.2); halves height and width."""
    x = Conv2D(filters, 5, strides=2, padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    return LeakyReLU(alpha=0.2)(x)


def unet_decoder_block(x, filters, dropout):
    """Decoder step: 5x5 stride-2 Conv2DTranspose -> BatchNorm -> optional Dropout(0.5) -> ReLU; doubles height and width."""
    x = Conv2DTranspose(filters, 5, strides=2, padding='same')(x)
    x = BatchNormalization(axis=-1)(x)
    if dropout:
        x = Dropout(0.5)(x)
    return Activation('relu')(x)


def unet(inputs=None):
    """U-Net generator that predicts a non-negative mask and applies it to its input.

    Six encoder blocks (64 -> 1024 filters) take a (256, 512, 1) input down to
    (4, 8, 1024). Five decoder blocks with skip connections bring it back to
    (128, 256, 64), and a final transposed convolution produces a one-channel mask at
    input resolution that is multiplied element-wise with `inputs`.

    Args:
        inputs: optional Keras Input tensor. Defaults to a fresh ``Input((256, 512, 1))``
            created per call. The old default built a single shared Input once, when the
            function was defined.

    Returns:
        keras Model mapping `inputs` to ``mask * inputs``.
    """
    if inputs is None:
        inputs = Input((256, 512, 1))

    conv1 = unet_encoder_block(inputs, 64)
    conv2 = unet_encoder_block(conv1, 128)
    conv3 = unet_encoder_block(conv2, 256)
    conv4 = unet_encoder_block(conv3, 512)
    conv5 = unet_encoder_block(conv4, 1024)
    conv6 = unet_encoder_block(conv5, 1024)

    # Decoder: each block (after the first) consumes the previous output concatenated with
    # the encoder feature map of matching resolution (skip connection).
    deconv7 = unet_decoder_block(conv6, 1024, dropout=True)
    deconv8 = unet_decoder_block(Concatenate(axis=-1)([deconv7, conv5]), 512, dropout=True)
    deconv9 = unet_decoder_block(Concatenate(axis=-1)([deconv8, conv4]), 256, dropout=True)
    deconv10 = unet_decoder_block(Concatenate(axis=-1)([deconv9, conv3]), 128, dropout=False)
    deconv11 = unet_decoder_block(Concatenate(axis=-1)([deconv10, conv2]), 64, dropout=False)

    # Output mask: BatchNorm BEFORE the ReLU (as in every other block) so the mask is
    # guaranteed non-negative. The original applied BatchNorm after the ReLU, which
    # re-centres the mask and can make it (and thus the masked output) negative.
    mask = Conv2DTranspose(1, 5, strides=2, padding='same')(Concatenate(axis=-1)([deconv11, conv1]))
    mask = BatchNormalization(axis=-1)(mask)
    mask = Activation('relu')(mask)

    output = Multiply()([mask, inputs])
    return Model(inputs=inputs, outputs=output)


# Instantiate the generator on a 256x512 single-channel input and render its architecture.
inputs = Input(shape=(256, 512, 1))
gen = unet(inputs=inputs)
gen.summary()
plot_model(
    gen,
    to_file='/kaggle/working/model_plot.png',
    show_shapes=True,
    show_layer_names=True,
)

In [None]:
# downsample block
def downsample(filters, size, dropout=False, batchnorm=True, dropout_rate=0.3):
    """Build a stride-2 downsampling block: Conv2D -> [BatchNorm] -> [Dropout] -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        dropout: if truthy, add ``Dropout(dropout_rate)`` after the (optional) BatchNorm.
        batchnorm: if truthy, add BatchNormalization after the convolution.
        dropout_rate: dropout rate used when `dropout` is enabled (default 0.3, the
            previously hard-coded value).

    Returns:
        A keras Sequential block; 'same' padding with stride 2 halves height and width.
    """
    init = tf.random_normal_initializer(0., 0.02)
    block = Sequential()
    # Bias is disabled (as before), presumably because BatchNorm's offset makes it redundant.
    block.add(Conv2D(filters, size, strides=2, padding="same", kernel_initializer=init, use_bias=False))
    if batchnorm:
        block.add(BatchNormalization())
    if dropout:
        block.add(Dropout(dropout_rate))
    block.add(LeakyReLU())
    return block

In [None]:
def discriminator():
    """Conditional discriminator over (mixture, separated) spectrogram pairs.

    The two (256, 512, 1) inputs are concatenated on the channel axis and downsampled
    three times (32 -> 64 -> 128 filters, stride 2). Two zero-padded 4x4 'valid'
    convolutions then produce a (30, 62, 1) map of unnormalised logits, one real/fake
    score per patch.

    Returns:
        keras Model taking ``[mixture, separated]`` and returning the logit map.
    """
    init = tf.random_normal_initializer(0., 0.02)

    inp = Input(shape=[256, 512, 1], name="mixture")
    # Input name "seperated" (sic) is kept unchanged because it is part of the model's interface.
    tar = Input(shape=[256, 512, 1], name="seperated")
    x = Concatenate()([inp, tar])
    down1 = downsample(32, 4, False)(x)
    down2 = downsample(64, 4)(down1)  # was 62: a typo breaking the 32/64/128 doubling
    down3 = downsample(128, 4)(down2)

    zero_pad1 = ZeroPadding2D()(down3)
    conv = Conv2D(256, 4, strides=1, kernel_initializer=init, use_bias=False)(zero_pad1)
    leaky_relu = LeakyReLU()(conv)
    zero_pad2 = ZeroPadding2D()(leaky_relu)
    # No activation: outputs are logits, matching BinaryCrossentropy(from_logits=True).
    last = Conv2D(1, 4, strides=1, kernel_initializer=init)(zero_pad2)
    return Model(inputs=[inp, tar], outputs=last)

In [None]:
# Build the conditional discriminator and display its layer graph.
disc = discriminator()
disc.summary()
plot_model(disc, dpi=64, show_shapes=True)

In [None]:
from keras.losses import BinaryCrossentropy
# Shared binary cross-entropy; the discriminator's last Conv2D has no activation, so it emits logits.
loss_function = BinaryCrossentropy(
    from_logits=True,
)

In [None]:
def generator_loss(disc_generated_output, input_, gen_output, target, l1_weight=100):
    """Generator objective: adversarial loss plus a weighted L1 reconstruction loss.

    Args:
        disc_generated_output: discriminator logits for (mixture, generated) pairs.
        input_: the mixture batch; currently unused, kept for call compatibility.
        gen_output: generator prediction.
        target: ground-truth bansuri spectrogram batch.
        l1_weight: multiplier on the L1 term (default 100, the previously hard-coded value).

    Returns:
        (total_gen_loss, gan_loss, l1_loss).
    """
    # The generator is rewarded when the discriminator labels its output as real (ones).
    gan_loss = loss_function(tf.ones_like(disc_generated_output), disc_generated_output)
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    total_gen_loss = gan_loss + l1_weight * l1_loss
    return total_gen_loss, gan_loss, l1_loss

In [None]:
def discriminator_loss(disc_real_output, disc_generated_output):
    """Discriminator objective: label real pairs 1 and generated pairs 0.

    Args:
        disc_real_output: discriminator logits for (mixture, ground-truth) pairs.
        disc_generated_output: discriminator logits for (mixture, generated) pairs.

    Returns:
        (total_disc_loss, real_loss, generated_loss), where total = real + generated.
    """
    loss_on_real = loss_function(tf.ones_like(disc_real_output), disc_real_output)
    loss_on_fake = loss_function(tf.zeros_like(disc_generated_output), disc_generated_output)
    return loss_on_real + loss_on_fake, loss_on_real, loss_on_fake

In [None]:
# Use `learning_rate=`, not the removed `lr=` alias. In TF/Keras 2.11-2.15 `lr=` is discarded
# with only a deprecation warning (so training silently ran at the default 1e-3), and
# Keras 3 rejects it outright.
generator_optimizer = Adam(learning_rate=5e-4, beta_1=0.5)
discriminator_optimizer = Adam(learning_rate=5e-4, beta_1=0.5)

In [None]:
@tf.function
def train_step(mixture, target, epoch, training_discriminator):
    """Run one optimisation step of the conditional GAN on a batch.

    Args:
        mixture: mixture batch, fed to the generator and used as the discriminator's condition.
        target: ground-truth bansuri batch.
        epoch: Python int epoch index. When ``epoch <= 8`` both networks are updated.
        training_discriminator: for ``epoch > 8``: True updates only the discriminator,
            False only the generator, and None updates both. Previously None fell into the
            "generator only" branch, so fit()'s warm-up epochs 9-10 never trained the
            discriminator.

    Returns:
        (gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss, disc_real_loss, disc_generated_loss).

    Note: `epoch` and `training_discriminator` are Python values, so tf.function retraces
    for each new value.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = gen(mixture, training=True)
        disc_real_output = disc([mixture, target], training=True)
        disc_generated_output = disc([mixture, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, mixture, gen_output, target)
        disc_loss, disc_real_loss, disc_generated_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Gradients are taken after the tapes close so the backward pass itself is not recorded.
    if epoch > 8 and training_discriminator is not None:
        # Alternating phase: update only the network currently in focus.
        if training_discriminator:
            discriminator_gradients = disc_tape.gradient(disc_loss, disc.trainable_variables)
            discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc.trainable_variables))
        else:
            generator_gradients = gen_tape.gradient(gen_total_loss, gen.trainable_variables)
            generator_optimizer.apply_gradients(zip(generator_gradients, gen.trainable_variables))
    else:
        # Joint phase: update both networks from the same forward pass.
        generator_gradients = gen_tape.gradient(gen_total_loss, gen.trainable_variables)
        discriminator_gradients = disc_tape.gradient(disc_loss, disc.trainable_variables)
        generator_optimizer.apply_gradients(zip(generator_gradients, gen.trainable_variables))
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, disc.trainable_variables))

    return gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss, disc_real_loss, disc_generated_loss

In [None]:
@tf.function
def validation_step(mixture, target):
    """Compute generator and discriminator losses on a batch without updating any weights.

    Both networks run in inference mode (``training=False``). No GradientTape is used
    because nothing is differentiated here; the previous tapes only recorded ops for nothing.

    Returns:
        (gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss, disc_real_loss, disc_generated_loss).
    """
    gen_output = gen(mixture, training=False)
    disc_real_output = disc([mixture, target], training=False)
    disc_generated_output = disc([mixture, gen_output], training=False)
    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, mixture, gen_output, target)
    disc_loss, disc_real_loss, disc_generated_loss = discriminator_loss(disc_real_output, disc_generated_output)
    return gen_total_loss, gen_gan_loss, gen_l1_loss, disc_loss, disc_real_loss, disc_generated_loss

In [None]:
# Per-batch losses for the epoch in progress; cleared at the end of every epoch.
each_epoch_train_generator_losses = []
each_epoch_train_generator_gan_losses = []
each_epoch_train_generator_l1_losses = []
each_epoch_train_discriminator_losses = []
each_epoch_train_discriminator_real_losses = []
each_epoch_train_discriminator_generated_losses = []
each_epoch_test_generator_losses = []
each_epoch_test_generator_gan_losses = []
each_epoch_test_generator_l1_losses = []
each_epoch_test_discriminator_losses = []
each_epoch_test_discriminator_real_losses = []
each_epoch_test_discriminator_generated_losses = []
# Per-epoch mean losses (one entry per completed epoch); used by the plotting cells.
train_generator_losses = []
train_generator_gan_losses = []
train_generator_l1_losses = []
train_discriminator_losses = []
train_discriminator_real_losses = []
train_discriminator_generated_losses = []
test_generator_losses = []
test_generator_gan_losses = []
test_generator_l1_losses = []
test_discriminator_losses = []
test_discriminator_real_losses = []
test_discriminator_generated_losses = []
# One flat row of 12 per-epoch means per epoch (column order documented in `fit`).
each_epoch_all_records = []
all_records = []


def loss_plateaued(losses, tolerance, window=5):
    """Return True if each of the last `window` values differs from its predecessor by less than `tolerance`.

    Uses the absolute change, so a loss that is still falling steadily is NOT treated
    as a plateau. Returns False when fewer than `window` values are available.
    """
    recent = losses[-window:]
    if len(recent) < window:
        return False
    return all(abs(recent[i] - recent[i - 1]) < tolerance for i in range(1, len(recent)))


def fit(train_ds, epochs, test_ds, gen_model):
    """Train the global `gen`/`disc` pair for up to `epochs` epochs, validating after each one.

    Epochs 0-10 call `train_step` with ``training_discriminator=None`` (intended to update
    both networks). From epoch 11 on, training alternates: when the monitored training loss
    (discriminator loss while focusing on the discriminator, generator loss otherwise) has
    plateaued for 4 consecutive epochs, the OTHER network's learning rate is reduced and the
    focus switches. Training stops after ~41000 s of wall-clock time or when a learning rate
    would drop to `min_learning_rate`.

    Args:
        train_ds: dataset of (mixture, bansuri) batches for training.
        epochs: maximum number of epochs.
        test_ds: dataset of (mixture, bansuri) batches for validation.
        gen_model: unused (the step functions use the global `gen`); kept for compatibility.

    Side effects:
        Appends per-epoch means to the module-level ``train_*`` / ``test_*`` lists, and one
        row per epoch to `each_epoch_all_records` and `all_records`. Row columns: train
        [gen_total, gen_gan, gen_l1, disc_total, disc_real, disc_generated], then the same
        six for validation.
    """
    # Ordered to match the tuple returned by train_step / validation_step.
    train_batch_lists = (
        each_epoch_train_generator_losses,
        each_epoch_train_generator_gan_losses,
        each_epoch_train_generator_l1_losses,
        each_epoch_train_discriminator_losses,
        each_epoch_train_discriminator_real_losses,
        each_epoch_train_discriminator_generated_losses,
    )
    test_batch_lists = (
        each_epoch_test_generator_losses,
        each_epoch_test_generator_gan_losses,
        each_epoch_test_generator_l1_losses,
        each_epoch_test_discriminator_losses,
        each_epoch_test_discriminator_real_losses,
        each_epoch_test_discriminator_generated_losses,
    )
    train_epoch_lists = (
        train_generator_losses,
        train_generator_gan_losses,
        train_generator_l1_losses,
        train_discriminator_losses,
        train_discriminator_real_losses,
        train_discriminator_generated_losses,
    )
    test_epoch_lists = (
        test_generator_losses,
        test_generator_gan_losses,
        test_generator_l1_losses,
        test_discriminator_losses,
        test_discriminator_real_losses,
        test_discriminator_generated_losses,
    )

    consecutive_epochs_high_loss = 0
    training_discriminator = True
    stop_training = False
    min_learning_rate = 1e-7  # Set your desired minimum learning rate
    time_budget_seconds = 41000  # presumably chosen to finish inside the Kaggle session limit
    time_start = time.time()

    for epoch in range(epochs):
        start = time.time()
        train_count = 0
        test_count = 0

        print(f"Epoch {epoch}")

        # ---- Training ----
        for input_, target in train_ds:
            # Epochs 0-10 are a warm-up (None -> both networks); afterwards the current focus is passed.
            focus = training_discriminator if epoch > 10 else None
            step_losses = train_step(input_, target, epoch, training_discriminator=focus)
            for batch_list, value in zip(train_batch_lists, step_losses):
                batch_list.append(value)
            train_count += 1

        train_means = [np.mean(batch_list) for batch_list in train_batch_lists]
        print("Training Details")
        # Means are in (total, gan, l1, disc, real, generated) order, so the labels now match
        # (real/generated were previously swapped).
        print("Generator-- total_loss:{:.5f} gan_loss:{:.5f} l1_loss:{:.5f} Discriminator-- total_loss:{:.5f} real_loss:{:.5f} generated_loss:{:.5f}".format(*train_means))
        print("Time taken for epoch {} is {} sec".format(epoch + 1, time.time() - start))
        print(f"Number of iteration {train_count}")
        for epoch_list, mean in zip(train_epoch_lists, train_means):
            epoch_list.append(mean)

        # ---- Validation ----
        for input_, target in test_ds:
            step_losses = validation_step(input_, target)
            for batch_list, value in zip(test_batch_lists, step_losses):
                batch_list.append(value)
            test_count += 1

        test_means = [np.mean(batch_list) for batch_list in test_batch_lists]
        print("Validation Details")
        print("Generator-- total_loss:{:.5f} gan_loss:{:.5f} l1_loss:{:.5f}  Discriminator-- total_loss:{:.5f} real_loss:{:.5f} generated_loss:{:.5f}".format(*test_means))
        print("Time taken for epoch {} is {} sec".format(epoch + 1, time.time() - start))
        print(f"Number of iteration {test_count}")
        for epoch_list, mean in zip(test_epoch_lists, test_means):
            epoch_list.append(mean)

        # Append this epoch's row. Previously the whole, ever-growing `each_epoch_all_records`
        # list was appended every epoch, so the CSV repeated all epochs on every row.
        record = train_means + test_means
        each_epoch_all_records.append(record)
        all_records.append(record)

        # Empty the per-batch accumulators for the next epoch.
        for batch_list in train_batch_lists + test_batch_lists:
            batch_list.clear()

        if time.time() - time_start > time_budget_seconds:
            stop_training = True

        # ---- Plateau detection (alternating phase only) ----
        if epoch > 10:
            if training_discriminator:
                plateaued = loss_plateaued(train_discriminator_losses, 0.01)
            else:
                plateaued = loss_plateaued(train_generator_losses, 0.05)
            consecutive_epochs_high_loss = consecutive_epochs_high_loss + 1 if plateaued else 0

        # After 4 consecutive plateaued epochs, slow the other network and switch focus.
        if consecutive_epochs_high_loss == 4:
            if training_discriminator:
                current_lr = generator_optimizer.learning_rate.numpy()
                new_lr = max(current_lr * 0.45, min_learning_rate)
                if new_lr > min_learning_rate:
                    generator_optimizer.learning_rate.assign(new_lr)
                    print(f"Reduced generator learning rate to {new_lr} at epoch {epoch + 1}")
                    training_discriminator = not training_discriminator
                else:
                    print("Generator learning rate already at the minimum. Stopping training.")
                    stop_training = True
            else:
                current_lr = discriminator_optimizer.learning_rate.numpy()
                new_lr = max(current_lr * 0.3, min_learning_rate)
                if new_lr > min_learning_rate:
                    discriminator_optimizer.learning_rate.assign(new_lr)
                    print(f"Reduced discriminator learning rate to {new_lr} at epoch {epoch + 1}.")
                    training_discriminator = not training_discriminator
                else:
                    print("Discriminator learning rate already at the minimum. Stopping training.")
                    stop_training = True
            # Reset after EVERY reduction. Previously only the discriminator branch reset it,
            # so after a generator-LR cut the counter could overshoot 4 and never fire again.
            consecutive_epochs_high_loss = 0

        if stop_training:
            break


In [None]:
# Train for at most 350 epochs (fit may stop earlier on its time budget or learning-rate floor).
fit(train_dataset, epochs=350, test_ds=validation_dataset, gen_model=gen)

In [None]:
# Training curves: total generator loss vs. total discriminator loss per epoch.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_generator_losses, '--b', label='Generator Loss')
ax.plot(epochs, train_discriminator_losses, '-.r', label='Discriminator Loss')
ax.set_title('Generator and Discriminator Training Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/train_loss_plot_1.png')
plt.show()

In [None]:
# Training curves: adversarial part of the generator loss vs. discriminator loss per epoch.
epochs = range(1, len(train_generator_losses) + 1)

plt.plot(epochs, train_generator_gan_losses, '--b', label='Generator GAN Loss')
plt.plot(epochs, train_discriminator_losses, '-.r', label='Discriminator Loss')
plt.title('Generator GAN and Discriminator Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
# Own file name: this cell previously reused 'train_loss_plot_1.png' and overwrote the
# total-loss figure saved by the previous cell.
plt.savefig('/kaggle/working/train_gan_loss_plot_1.png')
plt.show()

In [None]:
# Validation curves: total generator loss vs. total discriminator loss per epoch.
epochs = range(1, len(test_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, test_generator_losses, '--b', label='Generator Loss')
ax.plot(epochs, test_discriminator_losses, '-.r', label='Discriminator Loss')
ax.set_title('Generator and Discriminator Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/test_loss_plot_1.png')
plt.show()

In [None]:
# Generator adversarial loss: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_generator_gan_losses, '--b', label='Train_Generator_GAN_Loss')
ax.plot(epochs, test_generator_gan_losses, '-.r', label='Validation_Generator_GAN_Loss')
ax.set_title('Generator GAN Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Gen_GAN_plot_1.png')
plt.show()

In [None]:
# Generator L1 reconstruction loss: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_generator_l1_losses, '--b', label='Train_Generator_L1_Loss')
ax.plot(epochs, test_generator_l1_losses, '-.r', label='Validation_Generator_L1_Loss')
ax.set_title('Generator L1 Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Gen_L1_plot_1.png')
plt.show()

In [None]:
# Total generator loss: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_generator_losses, '--b', label='Train_Generator_Loss')
ax.plot(epochs, test_generator_losses, '-.r', label='Validation_Generator_Loss')
ax.set_title('Generator Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Gen_plot_1.png')
plt.show()

In [None]:
# Total discriminator loss: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_discriminator_losses, '--b', label='Train_Discriminator_Loss')
ax.plot(epochs, test_discriminator_losses, '-.r', label='Validation_Discriminator_Loss')
ax.set_title('Discriminator Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Disc_plot_1.png')
plt.show()

In [None]:
# Discriminator loss on real pairs: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_discriminator_real_losses, '--b', label='Train_Discriminator_Real_Loss')
ax.plot(epochs, test_discriminator_real_losses, '-.r', label='Validation_Discriminator_Real_Loss')
ax.set_title('Discriminator Real Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Disc_Real_plot_1.png')
plt.show()

In [None]:
# Discriminator loss on generated pairs: training vs. validation.
epochs = range(1, len(train_generator_losses) + 1)

fig, ax = plt.subplots()
ax.plot(epochs, train_discriminator_generated_losses, '--b', label='Train_Discriminator_Generated_Loss')
ax.plot(epochs, test_discriminator_generated_losses, '-.r', label='Validation_Discriminator_Generated_Loss')
ax.set_title('Discriminator Generated Loss for Training and Validation')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.legend()
fig.savefig('/kaggle/working/Disc_Generated_plot_1.png')
plt.show()

In [None]:
import csv
# Dump the per-epoch loss records collected by `fit` to CSV (no header row).
with open('/kaggle/working/dataJan20.csv', mode='w', newline='') as csv_file:
    writer = csv.writer(csv_file)
    writer.writerows(all_records)

In [None]:
# exist_ok=True makes the cell safe to re-run (os.mkdir raises FileExistsError the second time).
os.makedirs('/kaggle/working/modelsFeb20', exist_ok=True)

In [None]:
# Full models in (legacy) HDF5 format.
gen.save('/kaggle/working/modelsFeb20/gen.h5')
# Weights-only files use the `.weights.h5` suffix: Keras 3 rejects other suffixes in
# `save_weights`, and older tf.keras versions may fall back to a TF-checkpoint format
# under the misleading `.keras` name.
gen.save_weights('/kaggle/working/modelsFeb20/gen_weight.weights.h5')
disc.save('/kaggle/working/modelsFeb20/disc.h5')
disc.save_weights('/kaggle/working/modelsFeb20/disc_weight.weights.h5')

In [None]:
# test_audio = load('/kaggle/input/music-for-gan/music/val/music_147.wav')
# Audio(test_audio.numpy(), rate=44100)

In [None]:
# Separate the full mixture: run the generator over consecutive 256-frame windows and stitch
# the predictions back together. A final window aligned to the end covers the leftover
# frames; it overlaps the previous window and overwrites that part (as before).
n_frames = mixture_spectogram.shape[0]
if n_frames < 256:
    raise ValueError(f"Need at least 256 spectrogram frames for the generator, got {n_frames}")
number_of_chunks = n_frames // 256
chunk_starts = [i * 256 for i in range(number_of_chunks)] + [n_frames - 256]

# One batched `predict` call instead of one call per window (BatchNorm/Dropout run in
# inference mode, so batching does not change the per-window result).
windows = tf.stack([mixture_spectogram[start:start + 256, :] for start in chunk_starts])
predictions = gen.predict(tf.reshape(windows, (-1, 256, 512, 1)), batch_size=32)

target = np.zeros([n_frames, 512])
for start, pred in zip(chunk_starts, predictions):
    target[start:start + 256, :] = pred.reshape(256, 512)

In [None]:
# Turn the predicted spectrogram back into a waveform, using the same STFT settings
# as the forward transform (100 reconstruction iterations).
y = tfio.audio.inverse_spectrogram(
    target.astype(np.float32),
    nfft=1022,
    window=1022,
    stride=256,
    iterations=100,
)

In [None]:
# Listen to the separated bansuri estimate.
Audio(data=y, rate=44100)