# End-to-End Tutorial: Training a Neural Network with Keras and Xbatcher

## Import Required Libraries

In [None]:
import matplotlib.pyplot as plt
import tensorflow as tf
import xarray as xr
from keras import layers, models, optimizers

import xbatcher as xb
import xbatcher.loaders.keras

In [None]:
# Lazily open the Fashion-MNIST training split from a public Zarr store on S3.
# `chunks={}` asks xarray for dask-backed (lazy) arrays using the store's own
# chunking, and `anon: True` lets fsspec/s3fs read the bucket without AWS
# credentials.
ds = xr.open_dataset(
    's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',
    backend_kwargs={'storage_options': {'anon': True}},
    chunks={},
    engine='zarr',
)

## Define Batch Generators

In [None]:
# Two parallel batch generators: one over the images (features, X) and one
# over the labels (targets, y). Both cut 2000-sample windows along 'sample',
# so the i-th image batch lines up with the i-th label batch.
# preload_batch=False defers reading each batch until it is requested.
X_bgen = xb.BatchGenerator(
    ds['images'],
    preload_batch=False,
    input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},
)
y_bgen = xb.BatchGenerator(
    ds['labels'],
    preload_batch=False,
    input_dims={'sample': 2000},
)

## Map Batches to a Keras-Compatible Dataset

In [None]:
# Wrap both batch generators in xbatcher's Keras adapter (CustomTFDataset),
# which yields (images, labels) batch pairs as the output_signature below expects.
dataset = xbatcher.loaders.keras.CustomTFDataset(X_bgen, y_bgen)

# Expose the batches as a tf.data.Dataset so Keras can consume them directly.
# The leading (batch) dimension is left as None so this spec does not need to
# be kept in sync with the batch size chosen for the BatchGenerators; only
# the per-sample shape (1, 28, 28) is fixed here.
train_dataloader = tf.data.Dataset.from_generator(
    # from_generator calls this each time the dataset is iterated (e.g. once
    # per training epoch), so every pass gets a fresh iterator.
    lambda: iter(dataset),
    output_signature=(
        tf.TensorSpec(shape=(None, 1, 28, 28), dtype=tf.float32),  # Images
        tf.TensorSpec(shape=(None,), dtype=tf.int64),  # Labels
    ),
).prefetch(3)  # Prefetch 3 batches so loading overlaps with training

## Visualize a Sample Batch

In [None]:
# Pull a single batch from the tf.data pipeline. take(1) stops after one
# element, so the loop body runs exactly once and needs no explicit break.
for train_features, train_labels in train_dataloader.take(1):
    print(f'Feature batch shape: {train_features.shape}')
    print(f'Labels batch shape: {train_labels.shape}')

    # First image of the batch; squeeze drops the single channel axis
    # ((1, 28, 28) -> (28, 28)) so imshow receives a 2-D array.
    img = train_features[0].numpy().squeeze()
    label = train_labels[0].numpy()
    plt.imshow(img, cmap='gray')
    plt.title(f'Label: {label}')
    plt.show()

## Build a Simple Neural Network with Keras

In [None]:
# Define a simple feedforward classifier: flatten each 1x28x28 image, apply
# one hidden ReLU layer, then a 10-way softmax output.
model = models.Sequential(
    [
        # Declare the input shape with an explicit Input layer; passing
        # `input_shape=` to a regular layer is deprecated in Keras 3.
        layers.Input(shape=(1, 28, 28)),
        layers.Flatten(),  # Flatten input images
        layers.Dense(128, activation='relu'),  # Fully connected layer with 128 units
        layers.Dense(10, activation='softmax'),  # Output layer for 10 classes
    ]
)

# Compile the model. sparse_categorical_crossentropy takes integer class ids
# (not one-hot vectors), which matches the int64 label tensors fed in by the
# tf.data pipeline.
model.compile(
    optimizer=optimizers.Adam(learning_rate=0.001),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

# Display model summary
model.summary()

## Train the Model 

In [None]:
%%time

# Train the model for 5 epochs
epochs = 5

model.fit(
    train_dataloader,  # Pass the DataLoader directly
    epochs=epochs,
    verbose=1,  # Print progress during training
)

## Visualize a Sample Prediction

In [None]:
# Visualize the model's prediction for one image. The image comes from the
# training split (fashion-mnist-train.zarr), so this is a sanity check of the
# pipeline and model, not a measure of generalization.
# take(1) yields a single batch, so no explicit break is needed.
for train_features, train_labels in train_dataloader.take(1):
    img = train_features[0].numpy().squeeze()
    label = train_labels[0].numpy()

    # Slicing [:1] keeps the batch axis, so the model sees a batch of one.
    # verbose=0 suppresses the progress bar for this single-sample call.
    probabilities = model.predict(train_features[:1], verbose=0)
    predicted_label = probabilities.argmax(axis=1)[0]

    plt.imshow(img, cmap='gray')
    plt.title(f'True Label: {label}, Predicted: {predicted_label}')
    plt.show()

## Key Highlights 

- **Dynamic Batching**: Xbatcher's `BatchGenerator` (with `preload_batch=False`), wrapped by the `CustomTFDataset` class, reads each batch on demand instead of loading the whole dataset into memory up front, which keeps memory usage low.
- **Prefetching**: The prefetch feature in `tf.data.Dataset` overlaps data loading with model training to minimize idle time.
- **Compatibility**: The pipeline works seamlessly with `keras.Model.fit`, simplifying training workflows.