# In [None]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Reshape, BatchNormalization, Dense
from sklearn.model_selection import train_test_split
from scipy import stats
from ncps.tf import LTC
from tensorflow.keras.callbacks import ModelCheckpoint
from google.colab import drive

# Mount Google Drive so the spectrogram files under /content/drive can be read.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)
print("Mounting Google Drive... Done.")

# Input scaling for the noisy-voice spectrograms.
def scaled_in(matrix_spec):
    """Apply the fixed input scaling ``(x + 46) / 50``.

    Maps the value range [-96, 4] onto [-1, 1] (values outside that range
    land outside [-1, 1]). Works element-wise on scalars and NumPy arrays
    and returns a new object rather than modifying *matrix_spec*.
    """
    shift, divisor = 46, 50
    return (matrix_spec + shift) / divisor


print("Defining scaled_in function... Done.")

# Output scaling for the noise-model (target) spectrograms.
def scaled_ou(matrix_spec):
    """Apply the fixed output scaling ``(x - 6) / 82``.

    Maps the value range [-76, 88] onto [-1, 1] (6 maps to 0). Works
    element-wise on scalars and NumPy arrays and returns a new object
    rather than modifying *matrix_spec*.
    """
    centre, half_range = 6, 82
    return (matrix_spec - centre) / half_range


print("Defining scaled_ou function... Done.")

# Load data
# Directory on the mounted Drive that holds the precomputed spectrogram .npy files.
path_save_spectrogram = '/content/drive/MyDrive/npy/'

print("Loading data...")
X_in = np.load(os.path.join(path_save_spectrogram, 'noisy_voice_amp_db.npy'))
X_ou = np.load(os.path.join(path_save_spectrogram, 'voice_amp_db.npy'))
print("Data loaded successfully.")

# The subtraction below pairs noisy and clean spectrograms element-wise, so both
# stacks must have identical shapes. NumPy broadcasting would otherwise silently
# combine mismatched arrays (e.g. (N, H, W) - (H, W)) instead of failing.
if X_in.shape != X_ou.shape:
    raise ValueError(
        "Noisy and clean spectrogram arrays must have the same shape; "
        f"got {X_in.shape} and {X_ou.shape}"
    )

# Model of noise to predict: the training target is the noisy spectrogram
# minus the clean-voice spectrogram.
X_ou = X_in - X_ou

# Check distribution: the first summary is the noisy input, the second the noise target.
print("Checking distribution of input data...")
print(stats.describe(X_in.reshape(-1, 1)))
print(stats.describe(X_ou.reshape(-1, 1)))
print("Distribution check completed.")

# Scale input and output data with the fixed affine maps scaled_in / scaled_ou.
print("Scaling input and output data...")
X_in = scaled_in(X_in)
X_ou = scaled_ou(X_ou)
print("Data scaled successfully.")

# Check shape of spectrograms
print("Checking shape of spectrograms...")
print("Input shape:", X_in.shape)
print("Output shape:", X_ou.shape)
print("Shape check completed.")

# Reshape for training: append a single channel axis, (N, H, W) -> (N, H, W, 1).
# Both stacks are expected to be 3-D; fail with a clear message otherwise.
print("Reshaping data for training...")
if X_in.ndim != 3 or X_ou.ndim != 3:
    raise ValueError(
        "Expected 3-D spectrogram stacks (N, H, W); "
        f"got input {X_in.shape} and output {X_ou.shape}"
    )
X_in = X_in[..., np.newaxis]
X_ou = X_ou[..., np.newaxis]
print("Data reshaped successfully.")

# Split data into train and validation sets: 10 % held out, fixed seed for reproducibility.
print("Splitting data into train and validation sets...")
x_train, x_val, y_train, y_val = train_test_split(X_in, X_ou, test_size=0.10, random_state=42)
print("Data split completed.")

# Define RNN model. No pretrained UNet is loaded; this network is trained from scratch.
print("Defining RNN model...")
# NOTE(review): in ncps.tf the second positional argument of LTC appears to be
# `mixed_memory`, so `16` is most likely read as a truthy flag, not a size.
# If a 16-motor-neuron NCP wiring was intended, build ncps.wirings.AutoNCP(32, 16)
# and pass that wiring instead -- confirm the intent.
# return_sequences=True keeps one output per time step, so the Dense head below
# yields one value per spectrogram bin.
ncp = LTC(32, 16, return_sequences=True)
rnn_model = tf.keras.Sequential()
# Flatten each spectrogram into a sequence of scalars: (H, W, 1) -> (H*W, 1).
rnn_model.add(Reshape((-1, 1)))
rnn_model.add(ncp)
rnn_model.add(BatchNormalization())
# Per-time-step MLP head: Dense operates on the last axis only.
rnn_model.add(Dense(128, activation='relu'))
rnn_model.add(Dense(128, activation='relu'))
rnn_model.add(Dense(128, activation='relu'))
rnn_model.add(Dense(1, activation='linear'))
# Restore the target layout so predictions line up with y_train / y_val.
# Without this, the (batch, H*W, 1) output cannot be compared with the
# (batch, H, W, 1) targets in the loss.
rnn_model.add(Reshape(y_train.shape[1:]))
print("RNN model defined successfully.")

# Compile the RNN model
print("Compiling the RNN model...")
rnn_model.compile(optimizer='Adam',
                  loss='mean_squared_error',  # Use mean squared error loss
                  metrics=['mae'])
print("RNN model compiled successfully.")

# Checkpoint target: only the best model (by validation loss) is written here.
model_checkpoint_path = 'best_model_rnn.keras'

checkpoint = ModelCheckpoint(
    model_checkpoint_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=True,
    verbose=1,
)

# Fit on the training split, evaluating on the held-out split after each epoch.
print("Training the RNN model...")
history = rnn_model.fit(
    x_train,
    y_train,
    batch_size=10,
    epochs=20,
    validation_data=(x_val, y_val),
    callbacks=[checkpoint],
    verbose=1,
)
print("Training completed.")

# Save model architecture to JSON file
print("Saving model architecture to JSON file...")
model_json = rnn_model.to_json()
with open('model_rnn.json', 'w', encoding='utf-8') as json_file:
    json_file.write(model_json)
print("Model architecture saved successfully.")

# Save model weights.
# Keras 3 `Model.save_weights` only accepts filenames ending in `.weights.h5`
# (the `.keras` extension denotes the full-model archive written by
# `Model.save`). The `.weights.h5` name is also accepted by older tf.keras.
print("Saving model weights...")
rnn_model.save_weights('model_rnn_weights.weights.h5')
print("Model weights saved successfully.")

# Plot training and validation loss
import matplotlib.pyplot as plt

# Per-epoch loss values recorded by Model.fit.
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(loss) + 1)

plt.plot(epochs, loss, label='Training loss')
plt.plot(epochs, val_loss, label='Validation loss')
plt.title('Training and validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Print epoch information and other variables.
# list(...) prints the epoch numbers instead of the repr "range(1, N)".
print("Epochs:", list(epochs))
print("Training Loss:", loss)
print("Validation Loss:", val_loss)
