In [1]:
import tensorflow as tf
from tensorflow.keras.layers import LeakyReLU
import numpy as np

In [2]:
# ---- Hyper-parameters shared by the model-definition cells ----

# Generator: number of tracks it generates (used as the channel count of its last layer).
K_bar = 1

# Discriminator input layout: (num_sequential_bars, 96, 84, num_tracks).
num_sequential_bars = 4
num_tracks = 5

# Training batch size.
batch_size = 32

# Flag for the encoder -> decoder skip connections of the autoencoder.
skip_connections = True


In [3]:
#ENCODER / DECODER (autoencoder; skip connections are controlled by `skip_connections`)
# Every convolution uses kernel == stride with Keras' default 'valid' padding,
# so each Conv2D divides (and each Conv2DTranspose multiplies) exactly one axis
# by the kernel size. Shapes in the comments omit the batch axis.

def _merge_skip(decoder_tensor, encoder_tensor):
    """Add the same-shaped encoder feature map when `skip_connections` is enabled."""
    if skip_connections:
        return tf.keras.layers.add([decoder_tensor, encoder_tensor])
    return decoder_tensor

encoder_input = tf.keras.Input(shape=(96,84,1))
# Pitch axis: 84 -> 7 -> 1
encoder_pitchconv1 = tf.keras.layers.Conv2D(16,(1,12),(1,12),activation=LeakyReLU())(encoder_input)
encoder_bn1 = tf.keras.layers.BatchNormalization()(encoder_pitchconv1)      # (96, 7, 16)
encoder_pitchconv2 = tf.keras.layers.Conv2D(16,(1,7),(1,7),activation=LeakyReLU())(encoder_bn1)
encoder_bn2 = tf.keras.layers.BatchNormalization()(encoder_pitchconv2)      # (96, 1, 16)
# Time axis: 96 -> 32 -> 16 -> 8 -> 4
encoder_timeconv1 = tf.keras.layers.Conv2D(16,(3,1),(3,1),activation=LeakyReLU())(encoder_bn2)
encoder_bn3 = tf.keras.layers.BatchNormalization()(encoder_timeconv1)       # (32, 1, 16)
encoder_timeconv2 = tf.keras.layers.Conv2D(16,(2,1),(2,1),activation=LeakyReLU())(encoder_bn3)
encoder_bn4 = tf.keras.layers.BatchNormalization()(encoder_timeconv2)       # (16, 1, 16)
encoder_timeconv3 = tf.keras.layers.Conv2D(16,(2,1),(2,1),activation=LeakyReLU())(encoder_bn4)
encoder_bn5 = tf.keras.layers.BatchNormalization()(encoder_timeconv3)       # (8, 1, 16)
encoder_timeconv4 = tf.keras.layers.Conv2D(16,(2,1),(2,1),activation=LeakyReLU())(encoder_bn5)
encoder_bn6 = tf.keras.layers.BatchNormalization()(encoder_timeconv4)       # (4, 1, 16)
encoder_output = tf.keras.layers.Flatten()(encoder_bn6)                     # 64-d latent code

# The decoder mirrors the encoder; each up-sampled map is merged with the
# encoder activation of the same shape.
decoder_input = tf.keras.layers.Reshape((4,1,16))(encoder_output)
decoder_timeconv1 = tf.keras.layers.Conv2DTranspose(16,(2,1),(2,1),activation=LeakyReLU())(decoder_input)
decoder_bn1 = tf.keras.layers.BatchNormalization()(decoder_timeconv1)       # (8, 1, 16)
decoder_out1 = _merge_skip(decoder_bn1,encoder_bn5)
decoder_timeconv2 = tf.keras.layers.Conv2DTranspose(16,(2,1),(2,1),activation=LeakyReLU())(decoder_out1)
decoder_bn2 = tf.keras.layers.BatchNormalization()(decoder_timeconv2)       # (16, 1, 16)
decoder_out2 = _merge_skip(decoder_bn2,encoder_bn4)
decoder_timeconv3 = tf.keras.layers.Conv2DTranspose(16,(2,1),(2,1),activation=LeakyReLU())(decoder_out2)
decoder_bn3 = tf.keras.layers.BatchNormalization()(decoder_timeconv3)       # (32, 1, 16)
decoder_out3 = _merge_skip(decoder_bn3,encoder_bn3)
decoder_timeconv4 = tf.keras.layers.Conv2DTranspose(16,(3,1),(3,1),activation=LeakyReLU())(decoder_out3)
decoder_bn4 = tf.keras.layers.BatchNormalization()(decoder_timeconv4)       # (96, 1, 16)
decoder_out4 = _merge_skip(decoder_bn4,encoder_bn2)
decoder_pitchconv1 = tf.keras.layers.Conv2DTranspose(16,(1,7),(1,7),activation=LeakyReLU())(decoder_out4)
decoder_bn5 = tf.keras.layers.BatchNormalization()(decoder_pitchconv1)      # (96, 7, 16)
decoder_out5 = _merge_skip(decoder_bn5,encoder_bn1)
# Output layer: batch-normalise the linear conv output and apply the sigmoid
# LAST, so the reconstruction stays in [0, 1] as the binary cross-entropy loss
# used to train the autoencoder expects. (Normalising after the sigmoid would
# let the output leave that range.)
decoder_pitchconv2 = tf.keras.layers.Conv2DTranspose(1,(1,12),(1,12))(decoder_out5)
decoder_bn6 = tf.keras.layers.BatchNormalization()(decoder_pitchconv2)      # (96, 84, 1)
decoder_output = tf.keras.layers.Activation('sigmoid')(decoder_bn6)

In [4]:
#GENERATOR 1
# Maps a (1, 1, 128) input to a (96, 84, K_bar) output through transposed
# convolutions whose kernel equals their stride (each one scales a single axis).
# NOTE: the hidden layers still apply ReLU before batch normalisation; whether
# that order should be swapped as well is an open question.
#TODO: Double check! there is one more CNN layer here
generator_input = tf.keras.Input((1,1,128))

# (filters, kernel == stride) for every hidden layer, in order.
generator_hidden_specs = [
    (1024, (2,1)),   # -> (2, 1, 1024)
    (256,  (2,1)),   # -> (4, 1, 256)
    (256,  (2,1)),   # -> (8, 1, 256)
    (256,  (2,1)),   # -> (16, 1, 256)
    (128,  (2,1)),   # -> (32, 1, 128)
    (128,  (3,1)),   # -> (96, 1, 128)
    (64,   (1,7)),   # -> (96, 7, 64)
]

x = generator_input
for hidden_filters, hidden_kernel in generator_hidden_specs:
    x = tf.keras.layers.Conv2DTranspose(hidden_filters,hidden_kernel,hidden_kernel,activation="relu")(x)
    x = tf.keras.layers.BatchNormalization()(x)

# Output layer: normalise the linear conv output first and apply the sigmoid
# LAST so the generated values stay in [0, 1]. Batch-normalising after the
# sigmoid would let the output leave that range.
x = tf.keras.layers.Conv2DTranspose(K_bar,(1,12),(1,12))(x)   # -> (96, 84, K_bar)
x = tf.keras.layers.BatchNormalization()(x)
generator_output = tf.keras.layers.Activation("sigmoid")(x)


In [5]:
# PAIRED DISCRIMINATOR
# Takes two inputs of shape (num_sequential_bars, 96, 84, num_tracks), joins
# them along the last (channel) axis and scores the pair with one sigmoid unit.
discriminator_paired_input1 = tf.keras.Input((num_sequential_bars,96,84,num_tracks))
discriminator_paired_input2 = tf.keras.Input((num_sequential_bars,96,84,num_tracks))

# (filters, kernel_size, strides) of the Conv3D stack, applied in order.
paired_conv_specs = [
    (128, (2,1,1),  (1,1,1)),
    (128, (3,1,1),  (1,1,1)),
    (128, (1,1,12), (1,1,12)),
    (128, (1,1,7),  (1,1,7)),
    (128, (1,2,1),  (1,2,1)),
    (128, (1,2,1),  (1,2,1)),
    (256, (1,4,1),  (1,4,1)),
    (512, (1,3,1),  (1,2,1)),
]

x = tf.keras.layers.concatenate([discriminator_paired_input1,discriminator_paired_input2])
for paired_filters, paired_kernel, paired_strides in paired_conv_specs:
    x = tf.keras.layers.Conv3D(paired_filters,paired_kernel,paired_strides,activation=LeakyReLU())(x)

x = tf.keras.layers.Flatten()(x)
discriminator_paired_output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

In [6]:
# UNPAIRED DISCRIMINATOR
# Scores a single input of shape (num_sequential_bars, 96, 84, num_tracks)
# with one sigmoid unit.
discriminator_input = tf.keras.Input((num_sequential_bars,96,84,num_tracks))

# (filters, kernel_size, strides) of the Conv3D stack, applied in order.
unpaired_conv_specs = [
    (128, (2,1,1),  (1,1,1)),
    (128, (3,1,1),  (1,1,1)),
    (128, (1,1,12), (1,1,12)),
    (128, (1,1,7),  (1,1,7)),
    (128, (1,2,1),  (1,2,1)),
    (128, (1,2,1),  (1,2,1)),
    (256, (1,4,1),  (1,4,1)),
    (512, (1,3,1),  (1,2,1)),
]

x = discriminator_input
for unpaired_filters, unpaired_kernel, unpaired_strides in unpaired_conv_specs:
    x = tf.keras.layers.Conv3D(unpaired_filters,unpaired_kernel,unpaired_strides,activation=LeakyReLU())(x)

x = tf.keras.layers.Flatten()(x)
discriminator_output = tf.keras.layers.Dense(1, activation='sigmoid')(x)

In [7]:
# Wrap the generator graph in a Keras model and print its layer table.
generator = tf.keras.Model(inputs=generator_input, outputs=generator_output)
generator.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 1, 1, 128)]       0         
_________________________________________________________________
conv2d_transpose_6 (Conv2DTr (None, 2, 1, 1024)        263168    
_________________________________________________________________
batch_normalization_12 (Batc (None, 2, 1, 1024)        4096      
_________________________________________________________________
conv2d_transpose_7 (Conv2DTr (None, 4, 1, 256)         524544    
_________________________________________________________________
batch_normalization_13 (Batc (None, 4, 1, 256)         1024      
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 8, 1, 256)         131328    
_________________________________________________________________
batch_normalization_14 (Batc (None, 8, 1, 256)         1024  

In [8]:
# Build the two-input (paired) discriminator model and print its layer table.
discriminator_paired = tf.keras.Model(
    inputs=[discriminator_paired_input1, discriminator_paired_input2],
    outputs=discriminator_paired_output,
)
discriminator_paired.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 4, 96, 84, 5 0                                            
__________________________________________________________________________________________________
input_4 (InputLayer)            [(None, 4, 96, 84, 5 0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 4, 96, 84, 10 0           input_3[0][0]                    
                                                                 input_4[0][0]                    
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 3, 96, 84, 12 2688        concatenate[0][0]          

In [9]:
# Encoder-only model: input -> flattened latent code.
encoder = tf.keras.Model(inputs=encoder_input, outputs=encoder_output, name="encoder")
encoder.summary()

# Full autoencoder: shares the encoder layers and adds the decoder on top.
autoencoder = tf.keras.Model(inputs=encoder_input, outputs=decoder_output, name="autoencoder")
autoencoder.summary()

Model: "encoder"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 96, 84, 1)]       0         
_________________________________________________________________
conv2d (Conv2D)              (None, 96, 7, 16)         208       
_________________________________________________________________
batch_normalization (BatchNo (None, 96, 7, 16)         64        
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 96, 1, 16)         1808      
_________________________________________________________________
batch_normalization_1 (Batch (None, 96, 1, 16)         64        
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 32, 1, 16)         784       
_________________________________________________________________
batch_normalization_2 (Batch (None, 32, 1, 16)         64  

In [10]:
# Configure the autoencoder for training: Adam optimizer, binary cross-entropy loss.
autoencoder.compile(
    loss='binary_crossentropy',
    optimizer='adam',
)

# TRAINING STAGE
1. Train two autoencoders on orchestra and piano samples respectively.
2. Pre-train the generators with pretrain_discriminator, on orchestra and piano respectively (pix2pix approach).
3. Train on unpaired data with the Cycle-GAN approach (we update G once for every five updates of D and apply batch normalization only to G).