In [1]:
import tensorflow as tf

In [2]:

def tf_stft(x, y, sample_rate=None, max_freq_hz=5000, num_classes=8):
    """Turn a batch of waveforms into 3-channel spectrogram images with one-hot labels.

    The extra keyword arguments default to the values that used to be
    hard-coded, so ``dataset.map(tf_stft)`` keeps producing the same tensors.

    Args:
        x: Batched audio whose last axis has size 1 (it is squeezed away),
            leaving a rank-2 tensor; presumably (batch, samples, 1) as yielded
            by the audio dataset -- confirm against the caller.
        y: Integer class ids; passed straight to ``tf.one_hot``.
        sample_rate: Sampling rate of ``x``, in the same unit as
            ``max_freq_hz``. When None, the static length of axis 1 of the
            squeezed audio is used instead (the original behaviour), which only
            equals the real sampling rate if every clip lasts exactly one second.
        max_freq_hz: Upper bound of the frequency range kept from the STFT.
        num_classes: Depth of the one-hot label encoding.

    Returns:
        A tuple ``(images, labels)``. ``images`` is the magnitude spectrogram
        resized to 369 x 496 with the batch axis first and a trailing axis of
        3 identical channels; ``labels`` is ``tf.one_hot(y, num_classes)``.

    Raises:
        ValueError: If ``sample_rate`` is None and the sample axis of ``x`` has
            no static length to fall back on.
    """
    # Ensure correct shape: drop the trailing size-1 axis.
    x = tf.squeeze(x, axis=-1)

    if sample_rate is None:
        # Backward-compatible fallback: treat the clip length (in samples) as
        # the sampling rate, exactly as the previous hard-coded expression did.
        sample_rate = x.shape[1]
        if sample_rate is None:
            raise ValueError(
                "tf_stft needs `sample_rate` when the audio length is not "
                "statically known."
            )
    nyquist = sample_rate // 2

    # Create a spectrogram; the result has axes (..., frames, frequency bins).
    stft = tf.signal.stft(x,
                          frame_length=1024, frame_step=512,
                          window_fn=tf.signal.hamming_window)
    # Take the same range of frequencies as in the pretrained model: keep only
    # the bins below max_freq_hz. A bound above the Nyquist frequency simply
    # keeps every bin, because slicing past the end is clamped.
    n_bins_kept = int(stft.shape[2] * max_freq_hz / nyquist)
    stft = stft[:, :, :n_bins_kept]
    stft = tf.math.abs(stft)

    # Resize to the same shape as the input to the pretrained model.
    # Each frame is repeated 16 times along the frame axis before resizing.
    stft = tf.repeat(stft, 16, axis=1)
    # Move frequency to axis 0 and the batch to the last axis, then reverse the
    # frequency axis. With a rank-3 input tf.image.resize treats the tensor as
    # (height, width, channels), so the batch rides along as "channels".
    stft = tf.transpose(stft, perm=[2, 1, 0])[::-1]
    stft = tf.image.resize(stft, (369, 496))
    # Put the batch axis back in front.
    stft = tf.transpose(stft, perm=[2, 0, 1])

    # To 3-channel image (again, the same as the input to the pretrained model)
    stft = tf.expand_dims(stft, -1)
    stft = tf.image.grayscale_to_rgb(stft)

    return stft, tf.one_hot(y, num_classes)

# Build the training and validation datasets from one directory, holding out
# 20% of the files for validation (fixed seed so the split is reproducible).
# output_sequence_length pads/truncates every clip to 16000 samples.
train_ds, val_ds = tf.keras.utils.audio_dataset_from_directory(
    directory='train_audio_for_distill',
    batch_size=64,
    validation_split=0.2,
    seed=0,
    output_sequence_length=16000,
    subset='both')

# Transform all audio batches to spectrograms. The conversion runs in parallel
# and batches are prefetched so the preprocessing overlaps with training;
# element order and values are the same as with a plain sequential map.
train_ds = (train_ds
            .map(tf_stft, num_parallel_calls=tf.data.AUTOTUNE)
            .prefetch(tf.data.AUTOTUNE))
val_ds = (val_ds
          .map(tf_stft, num_parallel_calls=tf.data.AUTOTUNE)
          .prefetch(tf.data.AUTOTUNE))

Found 115049 files belonging to 8 classes.
Using 92040 files for training.
Using 23009 files for validation.


In [3]:
# Load the pretrained model, then clone it so the new model shares the
# architecture but starts from newly created weights.
# compile=False: the saved optimizer/loss state is not restored on load.
with tf.device('cpu'):
    pretrained_model = tf.keras.models.load_model('my_model', compile=False)

with tf.device('/gpu:2'):
    # clone_model instantiates new layers from the original layers' configs.
    # NOTE(review): this presumably also resets any statistics held by the
    # Normalization layers -- confirm whether they need to be copied over.
    distilled_model = tf.keras.models.clone_model(pretrained_model)

    distilled_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()]
    )

    # summary() prints the table itself and returns None, so wrapping it in
    # print() only appended a stray "None" line to the output.
    distilled_model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 367, 494, 8)       224       
                                                                 
 normalization (Normalizatio  (None, 367, 494, 8)      17        
 n)                                                              
                                                                 
 max_pooling2d (MaxPooling2D  (None, 183, 247, 8)      0         
 )                                                               
                                                                 
 conv2d_1 (Conv2D)           (None, 181, 245, 16)      1168      
                                                                 
 normalization_1 (Normalizat  (None, 181, 245, 16)     33        
 ion)                                                            
                                                        

In [4]:
# Number of full passes over the training set.
EPOCHS = 10

with tf.device('/gpu:2'):
    # Ask TensorFlow to log device placement decisions from here on.
    tf.debugging.set_log_device_placement(True)

    # Train the cloned model, evaluating on the validation split every epoch;
    # the returned History object is kept for later inspection.
    history = distilled_model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=EPOCHS,
        verbose=1,
    )

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


In [5]:
!mkdir -p distilled_model
distilled_model.save('distilled_model/my_model')



INFO:tensorflow:Assets written to: distilled_model/my_model/assets


INFO:tensorflow:Assets written to: distilled_model/my_model/assets
