In [5]:
import tensorflow as tf
from mpi4py import MPI

# Initialize MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Define the model
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Load and preprocess the dataset
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0

# Expand dimensions to fit Conv2D input shape
x_train = x_train[..., tf.newaxis]
x_test = x_test[..., tf.newaxis]

# Define the training function
def train(model, x_train, y_train, rank, size):
    n = len(x_train)
    chunk_size = n // size
    start = rank * chunk_size
    end = (rank + 1) * chunk_size if rank != size - 1 else n

    x_train_chunk = x_train[start:end]
    y_train_chunk = y_train[start:end]

    # Compile the model
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Train the model
    model.fit(x_train_chunk, y_train_chunk, epochs=1, batch_size=32, verbose=0)

    # Evaluate on training chunk
    train_loss, train_acc = model.evaluate(x_train_chunk, y_train_chunk, verbose=0)

    # Average accuracy across all nodes
    train_acc = comm.allreduce(train_acc, op=MPI.SUM)

    return train_acc / size

# Training loop
epochs = 5
for epoch in range(epochs):
    train_acc = train(model, x_train, y_train, rank, size)

    # Evaluate on test set
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    test_acc = comm.allreduce(test_acc, op=MPI.SUM)

    if rank == 0:
        print(f"Epoch {epoch + 1}: Train accuracy = {train_acc:.4f}, Test accuracy = {test_acc / size:.4f}")

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Epoch 1: Train accuracy = 0.9733, Test accuracy = 0.9720
Epoch 2: Train accuracy = 0.9847, Test accuracy = 0.9799
Epoch 3: Train accuracy = 0.9876, Test accuracy = 0.9816
Epoch 4: Train accuracy = 0.9843, Test accuracy = 0.9782
Epoch 5: Train accuracy = 0.9915, Test accuracy = 0.9834


In [4]:
pip install mpi4py

Collecting mpi4py
  Downloading mpi4py-4.0.3.tar.gz (466 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/466.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m460.8/466.3 kB[0m [31m17.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m466.3/466.3 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mpi4py
  Building wheel for mpi4py (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mpi4py: filename=mpi4py-4.0.3-cp311-cp311-linux_x86_64.whl size=4458272 sha256=8b51a5dfcf2c2c5c6919ef27e4d12b5e289e57f7526d2a66e32d57178939b444
  Stored in directory: /root/.cache/pip/wheels/5c/56/17/bf6ba37a