In [None]:
pip install ray

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import numpy as np
import ray
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Define the CNN model
def create_model():
    """Build and compile the small CNN used for 10-class image classification.

    Architecture (in order): Conv2D(32, 3x3, relu) on 28x28x1 inputs,
    2x2 max-pooling, Conv2D(64, 3x3, relu), 2x2 max-pooling, flatten,
    and a 10-unit softmax output layer.

    Returns:
        A compiled ``keras.Sequential`` model using the Adam optimizer,
        categorical cross-entropy loss, and accuracy as the reported metric.
    """
    # Declare the whole layer stack up front instead of appending one by one.
    layer_stack = [
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, kernel_size=(3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(10, activation='softmax'),
    ]
    cnn = keras.Sequential(layer_stack)
    cnn.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return cnn

# Define the remote training function
@ray.remote
def train_layer(X, y, weights, epochs=5, batch_size=128):
    """Ray task: fit a fresh copy of the CNN on one shard of the training data.

    Despite its name, this trains the whole model (every layer), not a single
    layer. A new model is built inside the task, seeded with ``weights``,
    fitted on ``(X, y)``, and its resulting weights are returned.

    Args:
        X: Training inputs for this shard, passed straight to ``model.fit``.
        y: Training targets for this shard, passed straight to ``model.fit``.
        weights: Weight list to start from, in the format accepted by
            ``model.set_weights`` (e.g. the output of ``model.get_weights()``).
        epochs: Number of passes over the shard. Defaults to 5.
        batch_size: Mini-batch size used by ``model.fit``. Defaults to 128.

    Returns:
        The model's weights after fitting, as returned by ``model.get_weights()``.
    """
    # The model is rebuilt inside the task and only its weights travel
    # to and from the caller.
    model = create_model()
    model.set_weights(weights)
    model.fit(X, y, epochs=epochs, batch_size=batch_size)
    return model.get_weights()

# Define the main training function
def train_cnn(X_train, y_train, num_nodes):
    """Train the CNN on sharded data with Ray and report test accuracy.

    The training set is split into ``num_nodes`` shards. Training runs in two
    rounds: round 0 trains on the even-indexed shards, round 1 on the
    odd-indexed shards. Within a round every shard is trained in parallel as
    a ``train_layer`` Ray task starting from the current weights, and the
    returned weights are averaged element-wise to form the weights for the
    next round.

    After training, the averaged weights are loaded into the local model,
    which is evaluated on ``x_test`` / ``y_test``. These two names are read
    from module scope, so they must be defined before this function is called.

    Args:
        X_train: Training inputs; split along the first axis.
        y_train: Training targets; split along the first axis like ``X_train``.
        num_nodes: Number of shards to split the training data into.

    Returns:
        None. The test accuracy is printed.
    """
    # Split the data across the nodes
    X_split = np.array_split(X_train, num_nodes)
    y_split = np.array_split(y_train, num_nodes)

    # Initialize the model weights
    model = create_model()
    model_weights = model.get_weights()

    # Two rounds of data-parallel training: even-indexed shards, then odd-indexed.
    for i in range(2):
        shard_ids = range(i, num_nodes, 2)
        if not shard_ids:
            # No shards in this round (e.g. num_nodes == 1 on the odd round).
            # Averaging an empty result would wipe out the trained weights.
            continue
        futures = [
            train_layer.remote(X_split[j], y_split[j], model_weights)
            for j in shard_ids
        ]
        node_weights = ray.get(futures)
        # Element-wise average of each weight tensor across this round's shards.
        model_weights = [sum(weights) / len(weights) for weights in zip(*node_weights)]

    # Load the trained weights into the local model; without this the
    # evaluation below would score the untrained initial network.
    model.set_weights(model_weights)

    # Evaluate the model on the test data
    _, accuracy = model.evaluate(x_test, y_test)
    print('Test accuracy:', accuracy)

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Normalize the pixel values to be between 0 and 1
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

# Add a channel dimension to the images
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Convert the labels to one-hot encoding
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

# Initialize Ray. ignore_reinit_error lets this cell be re-run in the same
# kernel without raising because Ray is already running.
ray.init(ignore_reinit_error=True)

# Train the model using Ray (train_cnn reads x_test / y_test from module scope)
train_cnn(x_train, y_train, 2)


2023-03-28 19:30:50,137	INFO worker.py:1553 -- Started a local Ray instance.
[2m[36m(pid=7285)[0m 2023-03-28 19:30:53.553097: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
[2m[36m(pid=7285)[0m 2023-03-28 19:30:53.553223: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/lib64-nvidia
[2m[36m(train_layer pid=7285)[0m 2023-03-28 19:30:54.889820: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


[2m[36m(train_layer pid=7285)[0m Epoch 1/5
  5/235 [..............................] - ETA: 2s - loss: 2.2573 - accuracy: 0.2313  
 15/235 [>.............................] - ETA: 2s - loss: 2.0721 - accuracy: 0.4474
 25/235 [==>...........................] - ETA: 2s - loss: 1.7739 - accuracy: 0.5612
 35/235 [===>..........................] - ETA: 2s - loss: 1.4827 - accuracy: 0.6319
 45/235 [====>.........................] - ETA: 2s - loss: 1.2713 - accuracy: 0.6776
 50/235 [=====>........................] - ETA: 2s - loss: 1.1899 - accuracy: 0.6952
[2m[36m(train_layer pid=7285)[0m Epoch 2/5
  6/235 [..............................] - ETA: 2s - loss: 0.1441 - accuracy: 0.9596
 16/235 [=>............................] - ETA: 2s - loss: 0.1258 - accuracy: 0.9653
 26/235 [==>...........................] - ETA: 2s - loss: 0.1216 - accuracy: 0.9657
 35/235 [===>..........................] - ETA: 2s - loss: 0.1226 - accuracy: 0.9652
 40/235 [====>.........................] - ETA: 2s - los