In [None]:
import tensorflow as tf
import gc
import logging
import time

# Configure root logging once at import time: INFO level, timestamped records.
# At this level the logging.debug(...) messages emitted elsewhere in this file
# are suppressed; lower the level to logging.DEBUG to see them.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def check_gpu():
    """Detect GPUs, try to enable memory growth on each one, and log the result.

    Memory growth asks TensorFlow to allocate GPU memory on demand rather than
    reserving it up front. TensorFlow rejects changes to this setting once the
    runtime has been initialized (``RuntimeError``) or when the device cannot
    be configured (``ValueError``). Those cases are logged as warnings and
    skipped so that a late call does not abort the caller.

    Returns:
        True if at least one GPU is visible to TensorFlow, False otherwise.
    """
    gpus = tf.config.list_physical_devices('GPU')
    if not gpus:
        logging.warning("No GPU found. Running on CPU. Batch size limit might be very high (system RAM).")
        return False

    logging.info(f"Found GPU(s): {len(gpus)}")
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except (RuntimeError, ValueError) as e:
            # The GPU is still usable; only the allocation strategy is unchanged.
            logging.warning(f"  {gpu.name} - Could not enable memory growth: {e}")
        else:
            logging.info(f"  {gpu.name} - Memory Growth enabled.")
    return True

def try_batch_size(model, optimizer, loss_fn, input_shape, output_shape, batch_size, dtype=tf.float32):
    """
    Attempts a single training step (forward, backward, optimizer update) on
    zero-filled dummy data with the given batch size.

    Args:
        model: The tf.keras.Model to test. It is built here if not built yet.
        optimizer: The tf.keras.optimizers.Optimizer.
        loss_fn: The tf.keras.losses.Loss function.
        input_shape: Shape of a single input sample (excluding batch dimension).
            Any sequence of ints (tuple or list).
        output_shape: Shape of a single output sample (excluding batch dimension).
            Any sequence of ints (tuple or list).
        batch_size: The batch size to attempt.
        dtype: Data type for the dummy tensors (e.g., tf.float32, tf.float16).

    Returns:
        True if the step completed. False if it raised
        tf.errors.ResourceExhaustedError (OOM) or any other exception; other
        exceptions are logged with a traceback and deliberately treated like
        OOM, since they might be memory related too.
    """
    tf.keras.backend.clear_session()  # Clear Keras global state from previous attempts
    gc.collect()  # Force garbage collection

    # Add batch dimension to shapes (tuple() lets callers pass lists as well)
    batch_input_shape = (batch_size,) + tuple(input_shape)
    batch_output_shape = (batch_size,) + tuple(output_shape)

    # Bind every local up front so the cleanup in `finally` is valid no matter
    # where a failure happens -- including an OOM while allocating `x` itself.
    x = y_true = y_pred = loss = grads = None
    try:
        logging.debug(f"Generating data for batch size {batch_size}...")
        # tf.Variable keeps the data resident on the target device (GPU);
        # trainable=False stops the GradientTape from watching the dummy data.
        x = tf.Variable(tf.zeros(batch_input_shape, dtype=dtype), trainable=False)
        y_true = tf.Variable(tf.zeros(batch_output_shape, dtype=dtype), trainable=False)
        logging.debug("Data generated.")

        logging.debug(f"Attempting train step with batch size {batch_size}...")
        with tf.GradientTape() as tape:
            # Ensure model variables are created (important for first run)
            if not model.built:
                model.build(batch_input_shape)
                logging.info(f"Model built with input shape: {batch_input_shape}")

            y_pred = model(x, training=True)
            loss = loss_fn(y_true, y_pred)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        logging.debug(f"Batch size {batch_size} successful.")
        return True

    except tf.errors.ResourceExhaustedError as e:
        logging.warning(f"OOM error with batch size {batch_size}: {e}")
        return False
    except Exception as e:
        # Treat other errors like OOM for safety, as they might be memory related too.
        # The traceback is logged so genuine configuration errors stay diagnosable.
        logging.error(f"An unexpected error occurred with batch size {batch_size}: {e}", exc_info=True)
        return False
    finally:
        # Drop tensor references on every exit path, then collect, so the
        # memory can be released before the next attempt.
        x = y_true = y_pred = loss = grads = None
        gc.collect()


def _run_trial(model_builder, optimizer_builder, loss_fn_builder,
               input_shape, output_shape, batch_size, dtype):
    """Build a fresh model/optimizer/loss, try one step at ``batch_size``, then release them.

    Returns True if the model could be built and the training step succeeded,
    False otherwise. Exceptions raised by the builder callables propagate.
    """
    model = optimizer = loss_fn = None
    try:
        # Rebuild everything for a clean state (important for memory)
        model = model_builder()
        optimizer = optimizer_builder()
        loss_fn = loss_fn_builder()

        if not model.built:
            try:
                model.build((batch_size,) + input_shape)
                logging.info(f"Model built with input shape: {(batch_size,) + input_shape}")
            except Exception as e:
                # A build failure counts as a failed attempt at this size; it
                # must not discard smaller sizes that already succeeded.
                logging.error(f"Failed to build model even with batch size {batch_size}: {e}")
                return False

        return try_batch_size(model, optimizer, loss_fn, input_shape, output_shape, batch_size, dtype)
    finally:
        # Release this attempt's objects on every exit path before the next one.
        model = optimizer = loss_fn = None
        gc.collect()


def find_max_batch_size(
    model_builder,
    optimizer_builder,
    loss_fn_builder,
    input_shape,
    output_shape,
    dtype=tf.float32,
    initial_batch_size=1,
    max_search_doubling=16 # Max sizes tried while doubling (largest = initial * 2**(max_search_doubling - 1))
    ):
    """
    Finds the maximum batch size that fits in memory.

    The search has two phases. The exponential phase starts at
    ``initial_batch_size`` and doubles after every successful attempt, for at
    most ``max_search_doubling`` attempts. If an attempt fails, the binary
    phase bisects between the last size that worked and the size that failed.
    Every attempt uses freshly built model, optimizer and loss objects.

    Args:
        model_builder: A function that returns a new instance of the Keras model.
        optimizer_builder: A function that returns a new instance of the optimizer.
        loss_fn_builder: A function that returns a new instance of the loss function.
        input_shape: Shape of a single input sample (excluding batch dimension).
        output_shape: Shape of a single output sample (excluding batch dimension).
        dtype: Data type for the dummy tensors (tf.float32, tf.float16).
        initial_batch_size: The starting batch size for the search (must be >= 1).
        max_search_doubling: Maximum number of attempts in the exponential phase.

    Returns:
        The largest batch size that succeeded. If the exponential phase never
        fails, this is the largest size it tested. Returns 0 if no attempted
        size worked or if ``initial_batch_size`` is smaller than 1.
    """
    check_gpu()

    # Normalize so list shapes work everywhere a tuple is concatenated below.
    input_shape = tuple(input_shape)
    output_shape = tuple(output_shape)
    logging.info(f"Starting search for max batch size. Input: {input_shape}, Output: {output_shape}, Dtype: {dtype}")

    if initial_batch_size < 1:
        # Doubling 0 (or a negative number) can never produce a usable size.
        logging.error(f"initial_batch_size must be >= 1, got {initial_batch_size}.")
        return 0

    batch_size = initial_batch_size
    last_successful_batch_size = 0
    oom_occurred = False

    # --- Exponential Search Phase ---
    logging.info("--- Starting Exponential Search Phase ---")
    for _ in range(max_search_doubling):
        logging.info(f"Trying batch size: {batch_size}")
        if not _run_trial(model_builder, optimizer_builder, loss_fn_builder,
                          input_shape, output_shape, batch_size, dtype):
            oom_occurred = True
            break  # First failing size found, move to binary search

        last_successful_batch_size = batch_size
        batch_size *= 2  # Double for next attempt

        # Add a small delay to allow GPU memory to stabilize if needed
        time.sleep(0.5)

    if not oom_occurred:
        logging.warning(f"Exponential search completed {max_search_doubling} doublings without OOM. Max tested batch size: {last_successful_batch_size}. GPU memory might be very large or model/data small.")
        # If it never failed, the last successful one is our best guess within the search limit
        return last_successful_batch_size

    # --- Binary Search Phase ---
    logging.info("--- Starting Binary Search Phase ---")
    lower_bound = last_successful_batch_size  # Largest size known to work (0 = none yet)
    upper_bound = batch_size  # Smallest size known to fail

    # Ensure there's a gap to search
    if upper_bound - lower_bound <= 1:
        logging.info(f"Binary search range too small ({lower_bound} to {upper_bound}). Result is {lower_bound}.")
        return lower_bound

    logging.info(f"Binary search range: ({lower_bound}, {upper_bound})")

    # Invariant: lower_bound works (or is 0), upper_bound fails. While the gap
    # exceeds 1 the midpoint lies strictly between them, so each pass shrinks it.
    while upper_bound - lower_bound > 1:
        test_batch_size = (lower_bound + upper_bound) // 2
        logging.info(f"Trying batch size: {test_batch_size}")

        if _run_trial(model_builder, optimizer_builder, loss_fn_builder,
                      input_shape, output_shape, test_batch_size, dtype):
            lower_bound = test_batch_size  # This size worked, try higher
        else:
            upper_bound = test_batch_size  # This size failed, try lower

        time.sleep(0.5)  # Optional delay

    logging.info(f"Binary search finished. Lower bound: {lower_bound}, Upper bound: {upper_bound}")
    return lower_bound  # The largest size that succeeded

# =======================================================
# ===== USER CONFIGURATION ==============================
# =======================================================

# --- 1. Define Input/Output Shapes (excluding batch dimension) ---
# These are per-sample shapes; the search code prepends the batch size when it
# allocates the zero-filled dummy inputs and targets.
# Example for image classification (e.g., MNIST/CIFAR)
# INPUT_SHAPE = (28, 28, 1) # Example: MNIST image shape
# OUTPUT_SHAPE = (10,)       # Example: 10 classes (one-hot encoded expected by CategoricalCrossentropy)

# Example for a simple MLP
INPUT_SHAPE = (784,)      # Example: Flattened MNIST
OUTPUT_SHAPE = (10,)       # Example: 10 classes

# Example for sequence data (e.g., NLP)
# INPUT_SHAPE = (128,)       # Example: Sequence length 128 (input IDs)
# OUTPUT_SHAPE = (128, 50)   # Example: Sequence length 128, 50 output features per token

# --- 2. Define Data Type ---
# DATA_TYPE is the dtype of the dummy input/target tensors fed to the model.
# It is not passed to the model builder, so it does not by itself change the
# dtype of the model's layers or weights.
# Use tf.float32 for standard precision, or tf.float16 for half-precision data.
# NOTE(review): real mixed-precision training is normally configured on the
# model side (and may need loss scaling) -- confirm against your training setup.
DATA_TYPE = tf.float32

# --- 3. Define Model Builder Function ---
# Replace this with YOUR model architecture
def build_my_model():
    """Return a new, independent instance of the model to measure.

    Replace the example architecture with your own. The search calls this
    builder once per attempt, so it must construct a fresh model every time
    rather than return a shared instance.

    Returns:
        A tf.keras.Model whose input matches INPUT_SHAPE and whose output
        matches OUTPUT_SHAPE (both excluding the batch dimension).
    """
    # Example: Simple MLP
    model = tf.keras.Sequential([
        # tf.keras.Input(shape=...) declares the input shape up front and is
        # accepted by both Keras 2 and Keras 3, whereas
        # InputLayer(input_shape=...) is a deprecated spelling in Keras 3.
        tf.keras.Input(shape=INPUT_SHAPE),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dropout(0.2),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(OUTPUT_SHAPE[0], activation='softmax') # Adjust activation based on your task
    ])
    # Example: Simple CNN for images (use INPUT_SHAPE=(H, W, C))
    # model = tf.keras.Sequential([
    #     tf.keras.Input(shape=INPUT_SHAPE),
    #     tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
    #     tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    #     tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
    #     tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    #     tf.keras.layers.Flatten(),
    #     tf.keras.layers.Dropout(0.5),
    #     tf.keras.layers.Dense(OUTPUT_SHAPE[0], activation="softmax"),
    # ])
    return model

# --- 4. Define Optimizer Builder Function ---
def build_my_optimizer():
    """Return a new optimizer instance for one search attempt.

    The example uses plain SGD with a learning rate of 1e-2. Swap in the
    optimizer you intend to train with.
    """
    # Alternative (Adam): tf.keras.optimizers.Adam(learning_rate=1e-3)
    sgd = tf.keras.optimizers.SGD(learning_rate=1e-2)
    return sgd


# --- 5. Define Loss Function Builder Function ---
def build_my_loss_function():
    """Return a new loss object for one search attempt.

    The example uses mean squared error (regression). Alternatives:
      * one-hot encoded class labels: tf.keras.losses.CategoricalCrossentropy()
      * integer class labels: tf.keras.losses.SparseCategoricalCrossentropy()
    """
    mse = tf.keras.losses.MeanSquaredError()
    return mse


# --- 6. Set Initial Batch Size for Search ---
# First batch size tried by find_max_batch_size; it is doubled after every
# successful attempt. Start with 1 for safety, or a power of 2 you suspect might work.
INITIAL_BATCH_SIZE = 1

# =======================================================
# ===== EXECUTION =======================================
# =======================================================

if __name__ == "__main__":
    # Run the search with the configuration defined above and report the result.
    max_bs = find_max_batch_size(
        model_builder=build_my_model,
        optimizer_builder=build_my_optimizer,
        loss_fn_builder=build_my_loss_function,
        input_shape=INPUT_SHAPE,
        output_shape=OUTPUT_SHAPE,
        dtype=DATA_TYPE,
        initial_batch_size=INITIAL_BATCH_SIZE
    )

    if max_bs > 0:
        logging.info(f"\nEstimated Maximum Batch Size: {max_bs}")
        logging.info("Note: This is an estimate based on a single training step.")
        logging.info("Actual training might have slightly different memory requirements.")
        # Suggest roughly 10% headroom. The previous formula only subtracted
        # max_bs % 2, which left every even result unchanged, so the "smaller"
        # suggestion was usually not smaller at all.
        safer_bs = max(1, (max_bs * 9) // 10)
        if safer_bs < max_bs:
            logging.info(f"Consider using a slightly smaller batch size (e.g., {safer_bs}) for stability.")
    else:
        logging.error("\nCould not find a working batch size, even the initial size failed.")
        logging.error("Check model complexity, input size, available GPU memory (use nvidia-smi), and data type (try float16 if applicable).")



**How to Use:**

1.  **Install TensorFlow:** If you haven't already: `pip install tensorflow`. The standard `tensorflow` package includes GPU support; the separate `tensorflow-gpu` package is deprecated and should not be used for new installs.
2.  **Configure:**
    * Modify the `INPUT_SHAPE` and `OUTPUT_SHAPE` tuples to match *your* data (excluding the batch dimension).
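    * Set the `DATA_TYPE` (usually `tf.float32`). It controls only the dtype of the dummy input and target tensors the script allocates; choosing `tf.float16` does not by itself enable mixed precision in the model. If you train with mixed precision, enable it the same way your training code does (for example, via the Keras mixed-precision policy) so the measurement reflects your real setup, and check that your GPU benefits from it.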
    * Replace the example model architecture inside the `build_my_model` function with *your* actual model definition. **Crucially, ensure it returns a new model instance each time it's called.**
    * Adjust `build_my_optimizer` and `build_my_loss_function` to return instances of the optimizer and loss you intend to use.
    * You can optionally change `INITIAL_BATCH_SIZE`.
3.  **Run:** Execute the Python script (`python your_script_name.py`).
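4.  **Observe:** The script will log its progress, doubling the batch size after each successful attempt. If an attempt fails (typically with an OOM error), it then performs a binary search between the last size that worked and the size that failed. If the doubling limit (`max_search_doubling`) is reached first, it stops and reports the largest size it tested.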
5.  **Result:** The script will print the estimated maximum batch size found.

**Important Considerations:**

* **GPU Memory:** Ensure no other processes are consuming significant GPU memory while you run the script (check with `nvidia-smi` in your terminal if you have an NVIDIA GPU).
* **Memory Fragmentation:** Sometimes, even if theoretically enough memory exists, fragmentation can prevent allocation. Restarting the Python kernel or even the machine can sometimes help.
* **Mixed Precision (`tf.float16`):** Training with mixed precision roughly halves the memory needed for activations and data. It runs on most hardware, but the speed benefit comes mainly from GPUs with Tensor Cores (NVIDIA Volta architecture or newer). You might also need to use loss scaling with your optimizer (`tf.keras.mixed_precision.LossScaleOptimizer`) for numerical stability during actual training.
* **Model Complexity:** Larger models (more layers, wider layers, larger kernels) consume more memory.
* **Estimate vs. Reality:** This script simulates one training step. Real training involves data loading pipelines, callbacks, validation steps, etc., which might slightly alter memory usage. It's often wise to use a batch size slightly *smaller* than the absolute maximum found for stability.
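* **CPU:** If run on a CPU, the limit will be system RAM, which is usually much less restrictive than GPU VRAM. The script will likely report a very large batch size — often simply the largest size the exponential phase tests (`initial_batch_size * 2**(max_search_doubling - 1)`, i.e. 32768 with the defaults) rather than a true memory limit.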