In [None]:
# Configuration for weighted_mse: how per-element loss weights are derived.

# Weighting scheme applied to the targets: 'exponential' -> exp(y_true),
# 'square' -> y_true ** 2. Any other value makes weighted_mse raise ValueError.
weighted_loss_type: str = 'exponential'
# First freq_bin index whose weight is used as-is; indices below it receive the
# smallest weight found at or above it, so the near-zero "peak" is not emphasised.
weight_min: int = 300

def weighted_mse(y_true, y_pred):
    """Mean squared error weighted by a function of the target values.

    Each squared error ``(y_true - y_pred) ** 2`` is multiplied by a weight
    derived from ``y_true`` according to the module-level ``weighted_loss_type``:

    * ``'exponential'`` -> ``exp(y_true)``
    * ``'square'``      -> ``y_true ** 2``

    Along axis 1, every index below the module-level ``weight_min`` has its
    weight replaced by the smallest weight found at indices ``>= weight_min``
    (taken per sample and per trailing channel), so the region below the
    cutoff is not emphasised by the weighting. If axis 1 has no indices at or
    beyond ``weight_min``, there is no reference weight and the weights are
    left unmodified.

    Args:
        y_true: Target tensor. Must be rank 3 with axis 1 being the bin axis
            indexed by ``weight_min`` (presumably ``(batch, freq_bins,
            channels)`` -- TODO confirm against the model output). It is cast
            to ``y_pred``'s dtype before use.
        y_pred: Prediction tensor with the same shape as ``y_true``.

    Returns:
        Scalar tensor: the mean over all elements of the weighted squared error.

    Raises:
        ValueError: If ``weighted_loss_type`` is not ``'exponential'`` or
            ``'square'``.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Mirror Keras' built-in losses: integer or mismatched-float targets would
    # otherwise break tf.exp or the subtraction below.
    y_true = tf.cast(y_true, y_pred.dtype)

    if weighted_loss_type == 'exponential':
        # NOTE: in float32, exp overflows to inf for inputs above ~88.
        weights = tf.exp(y_true)
    elif weighted_loss_type == 'square':
        weights = tf.square(y_true)
    else:
        raise ValueError("Invalid weighted_loss_type. Choose 'exponential' or 'square'.")

    # Length of axis 1, read dynamically so any input length / batch size works.
    num_bins = tf.shape(weights)[1]

    # Floor weight per sample and channel: the smallest weight at or beyond the
    # cutoff. Shape (batch, 1, channels). When num_bins <= weight_min this slice
    # is empty and reduce_min yields +inf, which the mask below never selects.
    min_weight = tf.reduce_min(weights[:, weight_min:, :], axis=1, keepdims=True)

    # True for bins below the cutoff, but only if a reference region exists;
    # otherwise every weight would be replaced by +inf and the loss would be
    # inf/NaN.
    below_cutoff = tf.logical_and(
        tf.range(num_bins) < weight_min,
        num_bins > weight_min,
    )
    # Shape (1, num_bins, 1); tf.where broadcasts it against
    # (batch, num_bins, channels), so no explicit tiling is required.
    mask = below_cutoff[tf.newaxis, :, tf.newaxis]

    weights = tf.where(mask, min_weight, weights)

    return tf.reduce_mean(weights * tf.square(y_true - y_pred))
