# 5.3 Train Neural Networks - Code Brief

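Condensed reference for training a binary Keras classifier with callbacks (early stopping, checkpointing, learning-rate reduction) and class weights for imbalanced data.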

## Setup

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

## Callbacks

In [None]:
# All three callbacks watch the same validation metric.
_monitored_metric = 'val_loss'

callbacks = [
    # Early stopping: end training after 15 epochs without val_loss improvement
    # and roll the model back to the best weights observed.
    EarlyStopping(
        monitor=_monitored_metric,
        patience=15,
        restore_best_weights=True,
        verbose=1,
    ),
    # Model checkpoint: write the model to disk only when val_loss improves.
    ModelCheckpoint(
        filepath='best_model.keras',
        monitor=_monitored_metric,
        save_best_only=True,
        verbose=1,
    ),
    # Learning-rate reducer: halve the LR after 5 stagnant epochs (well before
    # early stopping would trigger), never going below 1e-5.
    ReduceLROnPlateau(
        monitor=_monitored_metric,
        factor=0.5,
        patience=5,
        min_lr=1e-5,
        verbose=1,
    ),
]

# Keep individual handles for later inspection of each callback.
early_stopping, checkpoint, reduce_lr = callbacks

## Class Weights for Imbalanced Data

In [None]:
# Balanced weights for a binary target: w_c = N / (2 * N_c), so both classes
# contribute equally to the weighted loss regardless of how frequent they are.
# NOTE(review): assumes y_train holds only the labels 0 and 1 — confirm upstream.
_labels = np.asarray(y_train).ravel()  # accept lists, pandas Series, (N, 1) arrays
n_class_0 = int((_labels == 0).sum())
n_class_1 = int((_labels == 1).sum())
total = len(_labels)

# A class missing from y_train would otherwise produce an infinite weight.
if n_class_0 == 0 or n_class_1 == 0:
    raise ValueError(
        "y_train must contain both classes 0 and 1 "
        f"(found {n_class_0} zeros and {n_class_1} ones)"
    )

class_weights = {
    0: total / (2 * n_class_0),
    1: total / (2 * n_class_1),
}

## Training

In [None]:
EPOCHS = 100
BATCH_SIZE = 32

# Options for model.fit: the validation split drives the callbacks, and the
# class weights rebalance the loss.
_fit_options = dict(
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(X_val, y_val),
    class_weight=class_weights,
    callbacks=callbacks,
    verbose=1,
)

history = model.fit(X_train, y_train, **_fit_options)

## Training History

In [None]:
# Per-epoch curves recorded by model.fit (one entry per completed epoch).
train_loss = history.history['loss']
val_loss = history.history['val_loss']
# NOTE(review): the 'auc' keys exist only if the model was compiled with an AUC
# metric named 'auc' (e.g. tf.keras.metrics.AUC(name='auc')) — confirm in the
# compile step; otherwise these lines raise KeyError.
train_auc = history.history['auc']
val_auc = history.history['val_auc']

# EarlyStopping(restore_best_weights=True) can leave the model holding the
# weights from the lowest-val_loss epoch, so the last-epoch numbers need not
# describe the model now in memory. Report the best epoch alongside them.
best_idx = int(np.argmin(val_loss))

print(f"Final epoch: {len(train_loss)}")
print(f"Final val_loss: {val_loss[-1]:.4f}")
print(f"Final val_auc: {val_auc[-1]:.4f}")
print(f"Best epoch (lowest val_loss): {best_idx + 1}")
print(f"Best val_loss: {val_loss[best_idx]:.4f}")
print(f"val_auc at best epoch: {val_auc[best_idx]:.4f}")

## Generate Predictions

In [None]:
# Model scores on the validation set, flattened to a 1-D array.
# NOTE(review): assumes a single sigmoid output unit (binary classifier) — confirm.
y_pred_proba = np.ravel(model.predict(X_val))

# Hard labels: 1 where the score is strictly above 0.5, else 0.
y_pred = np.where(y_pred_proba > 0.5, 1, 0)

## Save Model

In [None]:
# Persist the model currently in memory in the native `.keras` format.
# NOTE(review): if EarlyStopping restored the best weights, this should match
# best_model.keras written by the checkpoint callback — confirm if relied upon.
MODEL_PATH = 'trained_model.keras'
model.save(MODEL_PATH)

## Key Concepts

| Concept | Description |
|:--------|:------------|
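| **Epoch** | One full pass through the training set |
| **Batch Size** | Number of samples processed per gradient (weight) update |
| **Early Stopping** | Stop training once the monitored validation metric has not improved for `patience` epochs |
| **Class Weights** | Scale each class's loss contribution to counter class imbalance |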