<a href="https://colab.research.google.com/github/supertime1/Speech_Emotion_Recognition/blob/main/SER.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import os
# Run from the repository root so the project-local modules imported below
# (data_handler, audio_processor, model.model_util) and the relative data paths
# used later ('raw_data', 'data/train', 'data/val') resolve. The default is the
# original author's machine; set SER_PROJECT_ROOT to run it anywhere else.
PROJECT_ROOT = os.environ.get(
    'SER_PROJECT_ROOT',
    'C:/Users/57lzhang.US04WW4008/PycharmProjects/Speech_Emotion_Recognition')
os.chdir(PROJECT_ROOT)

# `np` is used directly later in the notebook; import it explicitly instead of
# relying on it leaking out of the star imports below.
import numpy as np
from data_handler import *
from audio_processor import AudioProcessor
import tensorflow as tf
import librosa.display
import matplotlib.pyplot as plt
import sklearn
from model.model_util import *
from tensorflow.keras import Model
from tensorflow.keras.layers import Conv2D, Conv1D, BatchNormalization, Input, Add, Activation, \
    MaxPooling1D, Dropout, Flatten, TimeDistributed, Bidirectional, Dense, LSTM, ZeroPadding1D, \
    AveragePooling1D, GlobalAveragePooling1D, Concatenate, Permute, Dot, Multiply, RepeatVector, \
    Lambda, Average, GlobalAveragePooling2D, DepthwiseConv2D, MaxPooling2D, ZeroPadding2D

In [4]:
 # Multi-GPU data parallelism. This setup runs on Windows, where the NCCL-based
# cross-device communication cannot be used, so replace it with
# HierarchicalCopyAllReduce.
all_reduce_ops = tf.distribute.HierarchicalCopyAllReduce()
strategy = tf.distribute.MirroredStrategy(cross_device_ops=all_reduce_ops)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')


In [3]:
# --- Data handler: builds the train/val file splits from the raw recordings ---
raw_data_path = 'raw_data'
random_seed = 10
# Split ratios, interpreted inside DataHandler.
# NOTE(review): the printed counts below (80508 train / 9362 val) are roughly
# 90/10, so how val_ratio = 0.2 is applied is defined by DataHandler — confirm.
train_ratio, val_ratio = 0.9, 0.2
res_freq = 16000     # target sample rate, in Hz
block_span = 1       # block length, in seconds
stride_span = 30     # stride span, in milliseconds

data_handler = DataHandler(
    raw_data_path,
    train_ratio,
    val_ratio,
    res_freq,
    block_span,
    stride_span,
    random_seed,
)

# --- Audio processor: spectrogram / mel-spectrogram feature extraction ---
sample_freq = res_freq   # features are computed at the data handler's sample rate
slice_span = 16          # analysis window, in milliseconds
overlap_ratio = 0.75     # fraction of overlap between consecutive windows
n_mels = 64              # number of mel bands
snr = 20                 # presumably a noise-augmentation SNR (dB) — confirm in AudioProcessor
audio_processor = AudioProcessor(sample_freq, slice_span, overlap_ratio, n_mels, snr)

## Training

In [4]:
# Training hyperparameters (batch_size is the global batch across all replicas).
batch_size = 64
epochs = 100

# Infer the model input shape by pushing one block of dummy audio through the
# same mel transform used for training. Only the shape is kept, so the random
# signal content is irrelevant.
sample_data = np.random.rand(block_span * res_freq)
sample_mel, _ = audio_processor.mel_spectrogram(sample_data, 1)
sample_mel = np.expand_dims(sample_mel, axis=-1)  # append a channel axis
input_shape = sample_mel.shape
print(input_shape)

# File lists and example counts for the pre-split train / validation folders.
train_filenames, train_num_samples = data_handler.get_filenames('data/train')
val_filenames, val_num_samples = data_handler.get_filenames('data/val')

def preprocess_dataset(files):
    """Build an (unbatched) tf.data pipeline of model inputs from audio file paths.

    Each path is decoded with ``data_handler.get_waveform_and_label`` and the
    result is converted to a mel-spectrogram tensor with
    ``audio_processor.get_mel_tensor``. Both steps run with AUTOTUNE parallelism.

    Args:
        files: sequence/tensor of file paths accepted by
            ``tf.data.Dataset.from_tensor_slices``.

    Returns:
        A ``tf.data.Dataset`` yielding whatever ``get_mel_tensor`` returns.
    """
    return (
        tf.data.Dataset.from_tensor_slices(files)
        .map(data_handler.get_waveform_and_label, num_parallel_calls=tf.data.AUTOTUNE)
        .map(audio_processor.get_mel_tensor, num_parallel_calls=tf.data.AUTOTUNE)
    )

train_ds = preprocess_dataset(train_filenames)
val_ds = preprocess_dataset(val_filenames)

# Training pipeline: cache the decoded mel tensors once, then reshuffle on every
# epoch. The shuffle must come *after* .cache() (otherwise the first epoch's
# order is frozen into the cache) and *before* .batch() (so batch composition
# changes too, not just batch order). A buffer of train_num_samples would give a
# full shuffle but holds that many mel tensors in memory; 4096 is a compromise.
shuffle_buffer_size = 4096
train_ds = (train_ds
            .cache()
            .shuffle(shuffle_buffer_size, seed=random_seed, reshuffle_each_iteration=True)
            .batch(batch_size)
            .prefetch(tf.data.AUTOTUNE))
# Validation order does not affect metrics: batch, cache, prefetch as before.
val_ds = val_ds.batch(batch_size).cache().prefetch(tf.data.AUTOTUNE)

## early stop: halt after 20 epochs without val_loss improvement and roll back
## to the best weights seen.
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss',
                                              patience=20,
                                              restore_best_weights=True)
## learning rate decay callback. `decay` is not defined in this notebook;
## presumably it comes from one of the star imports (model.model_util) — confirm.
lr_schedule = tf.keras.callbacks.LearningRateScheduler(decay)
callback_list = [early_stop, lr_schedule]

(64, 251, 1)
Number of total examples: 80508
Example file tensor: tf.Tensor(b'data\\train\\4\\03-01-05-02-01-02-16_3840.wav', shape=(), dtype=string)
Number of total examples: 9362
Example file tensor: tf.Tensor(b'data\\val\\2\\03-01-03-02-01-02-08_0.wav', shape=(), dtype=string)


In [None]:
num_classes = 8

# MirroredStrategy only distributes variables created inside its scope, so the
# model must be built and compiled there. Built outside, `strategy` is unused
# and training runs on a single device. fit() may be called outside the scope.
with strategy.scope():
    model = MobileNet(input_shape=input_shape, classes=num_classes)
    model.summary()
    # from_logits=False assumes MobileNet ends in a softmax and that labels are
    # one-hot encoded (CategoricalCrossentropy) — TODO confirm in model_util.
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
                  metrics=[tf.keras.metrics.CategoricalAccuracy()])

history = model.fit(train_ds,
                    epochs=epochs,
                    validation_data=val_ds,
                    verbose=1,
                    callbacks=callback_list)

## Visualize the preprocessing

In [None]:
# Take one training waveform through the same loader the tf.data pipeline uses.
# (`train_data_generator` is not defined anywhere in this notebook.)
# NOTE(review): assumes get_waveform_and_label yields (waveform, label) — confirm.
waveform_ds = tf.data.Dataset.from_tensor_slices(train_filenames).map(
    data_handler.get_waveform_and_label)
waveform, _ = next(iter(waveform_ds))
testing_data = waveform.numpy()

# Frame hop in samples: a 16 ms window with 3/4 overlap at 16 kHz gives 64
# samples. This matches the 251 frames per 1 s block in the printed
# input_shape (64, 251, 1). Without it, specshow assumes a hop of 512 and
# mislabels the time axis.
# NOTE(review): assumes AudioProcessor derives its hop this way — confirm.
hop_length = int(round(sample_freq * slice_span / 1000 * (1 - overlap_ratio)))


def show_spectrogram(spec, title, y_axis):
    """Plot a 2-D spectrogram with a colorbar, using the notebook's sample rate and hop.

    Args:
        spec: 2-D array (frequency bins x frames).
        title: figure title.
        y_axis: librosa ``specshow`` frequency-axis type ('linear', 'mel', ...).
    """
    fig, ax = plt.subplots()
    img = librosa.display.specshow(spec, sr=sample_freq, hop_length=hop_length,
                                   x_axis='time', y_axis=y_axis, ax=ax)
    fig.colorbar(img, ax=ax)
    ax.set_title(title)
    plt.show()


# NOTE(review): spectrogram()'s return type is not shown in this notebook;
# assumes it returns a single 2-D array.
spec = audio_processor.spectrogram(testing_data)
show_spectrogram(spec, 'Spectrogram', 'linear')

# mel_spectrogram returns a 2-tuple (see the input_shape cell); call it the same
# way and keep only the spectrogram. Mel bins need a mel axis, not a linear one.
# y_axis='mel' assumes the filterbank spans 0 Hz..sr/2 (librosa's default fmax).
mel_spec, _ = audio_processor.mel_spectrogram(testing_data, 1)
show_spectrogram(mel_spec, 'Mel Spectrogram', 'mel')