In [None]:
# System-level Open MPI runtime and development headers; pip builds mpi4py
# from its source tarball, so these must be present first.
!apt-get update -qq
!apt-get install -y libopenmpi-dev openmpi-bin openmpi-common
# %pip (rather than !pip) installs into the environment of the running kernel.
# The version is pinned so a fresh run reproduces the same build.
%pip install -q mpi4py==4.0.3

W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list entry misspelt?)
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
libopenmpi-dev is already the newest version (4.1.2-2ubuntu1).
openmpi-bin is already the newest version (4.1.2-2ubuntu1).
openmpi-bin set to manually installed.
openmpi-common is already the newest version (4.1.2-2ubuntu1).
openmpi-common set to manually installed.
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.
Collecting mpi4py
  Downloading mpi4py-4.0.3.tar.gz (466 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m466.3/466.3 kB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproj

In [None]:
import tensorflow as tf
from mpi4py import MPI

# Initialize MPI
# `rank` identifies this process inside COMM_WORLD and `size` is the number of
# processes taking part; both are read by train() and by the training loop.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

# Model definition (same architecture on every rank)
# The input shape is declared with an explicit Input layer instead of passing
# `input_shape=` to Conv2D, which newer Keras releases warn about.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation='softmax')
])

# Every process draws its own random initial weights. Start all ranks from
# rank 0's parameters so that the per-epoch weight averaging combines models
# that share a starting point. With a single process this is a no-op.
model.set_weights(comm.bcast(model.get_weights(), root=0))

# Compile once (before training loop)
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Load and prepare data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Append a trailing channel axis and rescale pixel values by 1/255.
x_train = x_train[..., tf.newaxis] / 255.0  # Add channel dim
x_test = x_test[..., tf.newaxis] / 255.0

def train(model, x_train, y_train):
    """Fit `model` for one epoch on this rank's shard, then average weights.

    The training set is cut into `size` contiguous shards of `n // size`
    samples; the last rank additionally takes the remainder. After the local
    fit, each weight array is summed across all ranks and divided by `size`,
    and the averaged arrays are written back into `model`.

    Relies on the module-level `comm`, `rank`, `size` and `MPI`.

    Args:
        model: object exposing `fit`, `get_weights` and `set_weights`
            (a compiled Keras model in this notebook).
        x_train: sliceable sequence of training inputs.
        y_train: sliceable sequence of labels, aligned with `x_train`.

    Returns:
        None. `model` is updated in place.
    """
    # Data splitting
    n = len(x_train)
    chunk_size = n // size
    start = rank * chunk_size
    end = (rank + 1) * chunk_size if rank != size-1 else n

    # Local training
    model.fit(x_train[start:end], y_train[start:end],
              epochs=1, batch_size=32, verbose=0)

    # Weight synchronization
    # Reduce one array at a time. The lowercase (pickle-based) allreduce
    # combines objects with `+`; handing it the whole list would concatenate
    # the per-rank lists instead of adding the arrays element-wise, and
    # set_weights would then receive size * len(weights) arrays.
    averaged = [comm.allreduce(w, op=MPI.SUM) / size
                for w in model.get_weights()]
    model.set_weights(averaged)

# Training loop: one call to train() per epoch, on every rank
epochs = 5
for epoch in range(epochs):
    # Local fit on this rank's shard followed by the cross-rank weight average
    train(model, x_train, y_train)

    # Evaluation and reporting are left to rank 0; other ranks go straight
    # to the next epoch
    if rank != 0:
        continue
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"Epoch {epoch+1}: Test accuracy = {test_acc:.4f}")


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Epoch 1: Test accuracy = 0.9733
Epoch 2: Test accuracy = 0.9787
Epoch 3: Test accuracy = 0.9803
Epoch 4: Test accuracy = 0.9826
Epoch 5: Test accuracy = 0.9820
