In [None]:
pip install ray

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import numpy as np
import ray
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Start (or reuse) a local Ray instance. ignore_reinit_error=True lets this cell be
# re-run in the same kernel instead of raising because Ray is already initialized.
ray.init(ignore_reinit_error=True)

# Define the CNN model
def build_model(input_shape=(28, 28, 1), num_classes=10):
    """Build and compile a small two-block CNN classifier.

    Args:
        input_shape: Shape of a single input image, channels last.
            Defaults to (28, 28, 1), matching the MNIST preprocessing used below.
        num_classes: Number of units in the final softmax layer.

    Returns:
        A compiled ``keras.Sequential`` model. It uses categorical cross-entropy,
        so labels are expected one-hot encoded. It also uses the Adam optimizer
        and tracks accuracy.
    """
    model = keras.Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D(pool_size=(2, 2)),
        Conv2D(64, kernel_size=(3, 3), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(num_classes, activation='softmax'),
    ])
    model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
    return model

# Define the training function
@ray.remote
def train_cnn(X, y, model_weights, epochs=5, batch_size=128, verbose=2):
    """Train a fresh copy of the CNN on one data shard as a Ray task.

    Args:
        X: Training images for this shard, in the layout ``build_model`` expects.
        y: One-hot labels matching ``X``.
        model_weights: Initial weights (as returned by ``Model.get_weights()``),
            so every task starts from the same parameters.
        epochs: Number of passes over the shard.
        batch_size: Mini-batch size passed to ``fit``.
        verbose: Keras ``fit`` verbosity. Defaults to 2 (one line per epoch).
            Per-batch progress bars from concurrently running tasks interleave
            in the driver's output and become unreadable.

    Returns:
        The trained weights, as returned by ``Model.get_weights()``.
    """
    model = build_model()
    model.set_weights(model_weights)
    model.fit(X, y, batch_size=batch_size, epochs=epochs, verbose=verbose)
    return model.get_weights()

# Load MNIST once and preprocess the train and test splits identically
def preprocess_mnist(images, labels, num_classes=10):
    """Prepare raw MNIST arrays for ``build_model()``.

    Images are reshaped to (-1, 28, 28, 1), cast to float32 and scaled by 1/255.
    Labels are one-hot encoded.

    Args:
        images: Raw MNIST image array, as returned by ``keras.datasets.mnist.load_data()``.
        labels: Integer class labels matching ``images``.
        num_classes: Width of the one-hot encoding.

    Returns:
        ``(images, labels)`` after preprocessing.
    """
    images = images.reshape(-1, 28, 28, 1).astype('float32') / 255
    labels = keras.utils.to_categorical(labels, num_classes)
    return images, labels


(X_train_raw, y_train_raw), (X_test_raw, y_test_raw) = keras.datasets.mnist.load_data()
X_train, y_train = preprocess_mnist(X_train_raw, y_train_raw)
X_test, y_test = preprocess_mnist(X_test_raw, y_test_raw)

# Split the training data into one shard per node
num_nodes = 2
X_split = np.array_split(X_train, num_nodes)
y_split = np.array_split(y_train, num_nodes)

# Train the CNN on every shard in parallel, all starting from the same initial weights
model = build_model()
model_weights = model.get_weights()
# Store the initial weights in the object store once rather than serializing them per task.
# Ray resolves a top-level ObjectRef argument to its value before train_cnn runs.
model_weights_ref = ray.put(model_weights)
node_weights = ray.get(
    [train_cnn.remote(X_split[j], y_split[j], model_weights_ref) for j in range(num_nodes)]
)

# Summarize each node's result instead of dumping every weight array
for i, weights in enumerate(node_weights):
    norms = [round(float(np.linalg.norm(w)), 3) for w in weights]
    print(f"Node {i}: per-tensor L2 norms {norms}")

# Combine the weights from the different nodes (federated averaging).
# Each node is weighted by the number of samples it trained on, so the uneven shards
# np.array_split produces when len(X_train) % num_nodes != 0 are averaged correctly.
# Each result is cast back to the original dtype, because integer weights would
# otherwise promote float32 tensors to float64.
sample_counts = [len(shard) for shard in X_split]
new_weights = [
    np.average(np.stack(layer_weights), axis=0, weights=sample_counts).astype(layer_weights[0].dtype)
    for layer_weights in zip(*node_weights)
]
model.set_weights(new_weights)

# Evaluate the averaged model on the held-out test data
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=0)
print(f"Test loss: {test_loss}, Test accuracy: {test_acc}")


2023-03-28 19:51:27,764	INFO worker.py:1553 -- Started a local Ray instance.
[2m[36m(pid=3080)[0m 2023-03-28 19:51:32.521222: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
[2m[36m(pid=3080)[0m 2023-03-28 19:51:32.521390: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
[2m[36m(pid=3079)[0m 2023-03-28 19:51:32.519240: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or di

[2m[36m(train_cnn pid=3079)[0m Epoch 1/5
[2m[36m(train_cnn pid=3080)[0m Epoch 1/5
  1/235 [..............................] - ETA: 10:41 - loss: 2.3019 - accuracy: 0.0703
  2/235 [..............................] - ETA: 56s - loss: 2.2902 - accuracy: 0.1484  
  3/235 [..............................] - ETA: 55s - loss: 2.2842 - accuracy: 0.1797
  4/235 [..............................] - ETA: 1:00 - loss: 2.2719 - accuracy: 0.2090
  1/235 [..............................] - ETA: 12:36 - loss: 2.2984 - accuracy: 0.0938
  5/235 [..............................] - ETA: 1:02 - loss: 2.2617 - accuracy: 0.2422
  2/235 [..............................] - ETA: 1:18 - loss: 2.2945 - accuracy: 0.1133 
  3/235 [..............................] - ETA: 1:13 - loss: 2.2804 - accuracy: 0.1771
  6/235 [..............................] - ETA: 1:06 - loss: 2.2454 - accuracy: 0.2747
  4/235 [..............................] - ETA: 1:13 - loss: 2.2681 - accuracy: 0.2168
  7/235 [..............................