In [1]:
# Make the parent directory importable. This only appends '..' to sys.path;
# it does NOT change the current working directory (relative file paths below
# are still resolved against the notebook's own directory).
import sys
sys.path.append('..')
sys.dont_write_bytecode = True  # set before further imports so no __pycache__ is written for them

# Import relevant packages (stdlib first, then third-party)
import warnings

import numpy as np
import pandas as pd
import tensorflow as tf
from IPython.display import display

# Config
SEED = 42
# Seeds Python's `random`, NumPy's global RNG (used by np.random.permutation
# during training) and TensorFlow, so weight init and shuffling are reproducible.
tf.keras.utils.set_random_seed(SEED)

pd.set_option('display.max_columns', None)  # Ensure all columns are displayed
# NOTE(review): this silences *every* warning (including TF/pandas deprecation
# notices); consider narrowing it with `category=` / `module=` filters.
warnings.filterwarnings("ignore")

In [2]:
# Load the processed training and validation feature matrices (feather format).
# They are read in order: training first, then validation.
X_train, X_train_validate = map(
    pd.read_feather,
    (
        "../data/processed/X_train.feather",
        "../data/processed/X_train_validate.feather",
    ),
)

In [None]:
def build_autoencoder(input_dim, hidden_dims, activation='relu'):
    """Create a mirrored pair of Keras ``Sequential`` models (encoder, decoder).

    The encoder stacks one Dense layer per entry of ``hidden_dims``, in order.
    The decoder walks back through every width except the last one (the
    bottleneck) in reverse order, then ends with a sigmoid Dense layer of
    ``input_dim`` units, so every reconstructed feature lies in (0, 1).

    Args:
        input_dim: number of features to reconstruct.
        hidden_dims: sequence of encoder layer widths; the last entry is the
            code (bottleneck) size.
        activation: activation used by every hidden layer; the output layer
            is always sigmoid.

    Returns:
        Tuple ``(encoder, decoder)`` of ``tf.keras.Sequential`` models.
    """
    def dense(units, act):
        # Glorot-uniform is written out explicitly so the init scheme is visible.
        return tf.keras.layers.Dense(units, activation=act,
                                     kernel_initializer='glorot_uniform')

    encoder_layers = [dense(width, activation) for width in hidden_dims]
    decoder_layers = [dense(width, activation) for width in reversed(hidden_dims[:-1])]
    decoder_layers.append(dense(input_dim, 'sigmoid'))

    return tf.keras.Sequential(encoder_layers), tf.keras.Sequential(decoder_layers)

def compute_loss(x, x_hat, model, lam, real_idx, binary_idx):
    """Mixed-type reconstruction loss plus an L2 penalty on kernel weights.

    loss = mean_over_batch( sum_{j in real_idx}   0.5 * (x_j - x_hat_j)^2
                          + sum_{j in binary_idx} BCE(x_j, x_hat_j) )
           + (lam / 2) * sum_{kernels W} ||W||^2

    Args:
        x: input batch; columns are gathered along axis 1.
        x_hat: reconstruction of ``x`` with the same shape. Binary columns are
            treated as probabilities and clipped to [eps, 1 - eps] before the log.
        model: Keras model whose trainable variables with 'kernel' in their
            name are penalised (biases are excluded).
        lam: L2 regularisation coefficient.
        real_idx: column indices scored with squared error. May be empty.
        binary_idx: column indices scored with binary cross-entropy. May be empty.

    Returns:
        Scalar tensor: batch-mean reconstruction loss + L2 term.
    """
    # Cast indices explicitly: tf.convert_to_tensor([]) yields float32, which
    # tf.gather rejects, so an all-real or all-binary feature set would crash.
    real_idx = tf.cast(tf.convert_to_tensor(real_idx), tf.int32)
    binary_idx = tf.cast(tf.convert_to_tensor(binary_idx), tf.int32)

    # Squared-error term (with the 1/2 factor) for real-valued features
    x_real = tf.gather(x, real_idx, axis=1)
    x_hat_real = tf.gather(x_hat, real_idx, axis=1)
    mse_loss = tf.reduce_sum(0.5 * tf.square(x_real - x_hat_real), axis=1)

    # Binary cross-entropy term for binary-valued features
    x_bin = tf.gather(x, binary_idx, axis=1)
    x_hat_bin = tf.gather(x_hat, binary_idx, axis=1)
    eps = 1e-7  # keeps log() finite when the sigmoid saturates
    x_hat_bin = tf.clip_by_value(x_hat_bin, eps, 1 - eps)
    ce_loss = tf.reduce_sum(
        -x_bin * tf.math.log(x_hat_bin) - (1 - x_bin) * tf.math.log(1 - x_hat_bin),
        axis=1
    )

    # Per-sample reconstruction loss, averaged over the batch
    recon_loss = tf.reduce_mean(mse_loss + ce_loss)

    # L2 penalty on weight matrices only (variables named '...kernel...')
    kernel_sq_norms = [
        tf.reduce_sum(tf.square(w))
        for w in model.trainable_variables
        if 'kernel' in w.name
    ]
    # tf.add_n raises on an empty list (e.g. a model with no kernels yet)
    l2_reg = (tf.add_n(kernel_sq_norms) if kernel_sq_norms
              else tf.constant(0.0, dtype=recon_loss.dtype))
    reg_term = (lam / 2.0) * l2_reg

    return recon_loss + reg_term

def _resolve_feature_indices(n_features, loss_type, real_idx, binary_idx):
    """Return ``(real_idx, binary_idx)`` as int32 arrays covering the feature columns."""
    all_idx = np.arange(n_features)
    if real_idx is None and binary_idx is None:
        if loss_type == 'mse':
            real_idx, binary_idx = all_idx, []
        elif loss_type == 'bce':
            real_idx, binary_idx = [], all_idx
        else:
            raise ValueError(
                f"Unknown loss_type {loss_type!r}; expected 'mse' or 'bce', "
                "or pass real_idx/binary_idx explicitly."
            )
    elif real_idx is None:
        real_idx = np.setdiff1d(all_idx, binary_idx)
    elif binary_idx is None:
        binary_idx = np.setdiff1d(all_idx, real_idx)
    # int32 so tf.gather accepts them even when one of the lists is empty
    return np.asarray(real_idx, dtype=np.int32), np.asarray(binary_idx, dtype=np.int32)


def train_autoencoder(x_train, x_val, learning_rate=1e-3, lam=1e-4, max_epochs=100, patience_limit=10,
                      hidden_dims=(64, 32), loss_type='mse', real_idx=None, binary_idx=None):
    """Train an autoencoder with per-instance Adam updates and early stopping.

    Each epoch shuffles the training rows and applies one optimizer step per
    row (batch size 1). After every epoch the whole validation set is scored
    with ``compute_loss``. Training stops after ``patience_limit`` consecutive
    epochs without improvement, and the best-epoch weights are restored.

    NOTE(review): the decoder's output layer is sigmoid, so real-valued columns
    presumably need to be scaled to [0, 1] upstream. TODO: confirm preprocessing.

    Args:
        x_train: 2-D array-like of training rows (a DataFrame is accepted and
            converted to a float32 ndarray).
        x_val: 2-D array-like of validation rows with the same columns.
        learning_rate: Adam learning rate.
        lam: L2 coefficient forwarded to ``compute_loss``.
        max_epochs: maximum number of passes over ``x_train``.
        patience_limit: epochs without validation improvement before stopping.
        hidden_dims: encoder layer widths; the decoder mirrors them.
        loss_type: used only when neither ``real_idx`` nor ``binary_idx`` is
            given. ``'mse'`` treats every column as real-valued; ``'bce'``
            treats every column as binary.
        real_idx: column indices scored with squared error.
        binary_idx: column indices scored with cross-entropy. If only one of
            ``real_idx``/``binary_idx`` is given, the other defaults to the
            remaining columns.

    Returns:
        ``(encoder, decoder)``, holding the best-validation weights when any
        epoch produced a finite improvement.

    Raises:
        ValueError: if ``loss_type`` is unknown and no indices were given.
    """
    # np.asarray makes positional row indexing work for DataFrames too
    # (df[idx] would select columns instead of rows).
    x_train = np.asarray(x_train, dtype=np.float32)
    x_val = np.asarray(x_val, dtype=np.float32)
    input_dim = x_train.shape[1]
    real_idx, binary_idx = _resolve_feature_indices(input_dim, loss_type, real_idx, binary_idx)

    # Model definition
    encoder, decoder = build_autoencoder(input_dim, hidden_dims)
    autoencoder = tf.keras.Sequential([encoder, decoder])
    optimizer = tf.keras.optimizers.Adam(learning_rate)

    x_val_tensor = tf.convert_to_tensor(x_val)  # constant across epochs, so converted once
    best_val_loss = np.inf
    best_weights = None
    patience = 0

    for epoch in range(max_epochs):
        # Per-instance gradient update over a fresh shuffle of the training rows
        for i in np.random.permutation(len(x_train)):
            xi = tf.convert_to_tensor(x_train[i:i + 1])  # keep a leading batch dim of 1
            with tf.GradientTape() as tape:
                x_hat = autoencoder(xi, training=True)
                loss = compute_loss(xi, x_hat, autoencoder, lam, real_idx, binary_idx)

            gradients = tape.gradient(loss, autoencoder.trainable_variables)
            optimizer.apply_gradients(zip(gradients, autoencoder.trainable_variables))

        # Validation loss
        x_val_hat = autoencoder(x_val_tensor, training=False)
        val_loss = compute_loss(x_val_tensor, x_val_hat, autoencoder, lam,
                                real_idx, binary_idx).numpy()

        print(f"Epoch {epoch+1}, Val Loss: {val_loss:.5f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience = 0
            best_weights = autoencoder.get_weights()
        else:
            patience += 1
            if patience >= patience_limit:
                print("Early stopping triggered.")
                break

    # best_weights stays None if no epoch improved (e.g. NaN losses or max_epochs == 0)
    if best_weights is not None:
        autoencoder.set_weights(best_weights)
    else:
        print("No improving epoch recorded; keeping final weights.")
    return encoder, decoder