In [1]:
%pip install "ray[default]"
%pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

Collecting aiohttp-cors (from ray[default])
  Using cached aiohttp_cors-0.7.0-py3-none-any.whl (27 kB)
Collecting colorful (from ray[default])
  Using cached colorful-0.5.5-py2.py3-none-any.whl (201 kB)
Collecting py-spy>=0.2.0 (from ray[default])
  Using cached py_spy-0.3.14-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (3.0 MB)
Collecting gpustat>=1.0.0 (from ray[default])
  Using cached gpustat-1.1-py3-none-any.whl
Collecting opencensus (from ray[default])
  Using cached opencensus-0.11.2-py2.py3-none-any.whl (128 kB)
Collecting pydantic (from ray[default])
  Downloading pydantic-1.10.8-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m46.2 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Collecting smart-open (from ray[default])
  Using cached smart_open-6.3.0-py3-none-any.whl (56 kB)
Collecting nvidia-ml-py>=11.450.129 (from gpustat>=1.0.0->ray[default])
  Using cached nv

In [1]:
import ray
ray.init(
    address="ray://example-cluster-kuberay-head-svc:10001"
    #address="ray://raycluster-complete-head-svc:10001"
)

0,1
Python version:,3.10.9
Ray version:,2.3.0
Dashboard:,http://10.48.0.30:8265


In [2]:
from pprint import pprint

nodelist = []
node_ids = []
for node in ray.nodes():
    nodelist.append(node["NodeManagerAddress"])
    node_ids.append(node["NodeID"])
    pprint(node)
    
#pprint(ray.nodes())
tpu_worker_list = ",".join(nodelist)
num_workers = len(nodelist)
pprint ("{} total workers: {}".format(num_workers, node_ids))
print("TPU_WORKER_HOSTNAMES={}".format(tpu_worker_list))

{'Alive': True,
 'MetricsExportPort': 8080,
 'NodeID': '3caa95c01d80677ca6b5ef899e10aa3eb19448166bd2a4fe361b1e88',
 'NodeManagerAddress': '10.48.0.30',
 'NodeManagerHostname': 'example-cluster-kuberay-head-bbxhl',
 'NodeManagerPort': 41741,
 'NodeName': '10.48.0.30',
 'ObjectManagerPort': 45843,
 'ObjectStoreSocketName': '/tmp/ray/session_2023-05-25_18-39-38_622164_8/sockets/plasma_store',
 'RayletSocketName': '/tmp/ray/session_2023-05-25_18-39-38_622164_8/sockets/raylet',
 'Resources': {'CPU': 8.0,
               'google.com/tpu': 4.0,
               'memory': 20000000000.0,
               'node:10.48.0.30': 1.0,
               'object_store_memory': 19000001945.0},
 'alive': True}
"1 total workers: ['3caa95c01d80677ca6b5ef899e10aa3eb19448166bd2a4fe361b1e88']"
TPU_WORKER_HOSTNAMES=10.48.0.30


In [3]:
import jax
import os
from jax import numpy as jnp, random

actor_runtime_vars = {
    "env_vars": {
        "TPU_SKIP_MDS_QUERY": "true",
        "TPU_WORKER_HOSTNAMES": tpu_worker_list,
        "JAX_PLATFORMS": "tpu",
    },
}

#print(jax.device_count())
#assert jax.devices()[0].platform == 'tpu'
#jax.distributed.initialize()

# Define the neural network architecture
input_size = 10
hidden_size = 5
output_size = 2

@ray.remote(scheduling_strategy="SPREAD", resources={"google.com/tpu": 4})
class JaxNNActor:
    
    def __init__(self, worker_id, num_samples, learning_rate, num_epochs):
        self.worker_id = worker_id
        self.num_samples = num_samples
        self.learning_rate = learning_rate
        self.num_epochs = num_epochs
        pass
    
    def initialize(self):
        key = random.PRNGKey(0)
        self.params = {
            'W1': random.normal(key, (input_size, hidden_size)),
            'b1': jnp.zeros(hidden_size),
            'W2': random.normal(key, (hidden_size, output_size)),
            'b2': jnp.zeros(output_size)
        }
          
    # Define the forward pass of the network
    def forward(self, params, x):
        hidden = jax.nn.relu(jnp.dot(x, params['W1']) + params['b1'])
        output = jnp.dot(hidden, params['W2']) + params['b2']
        return output

    # Define the loss function
    def loss(self, params, x, y):
        y_pred = self.forward(params, x)
        return jnp.mean((y_pred - y)**2)

    # Define the training function
    #@jax.jit -- instead of the jax.jit decorator, we'll jit the function explicitly in the run, since this gets sent remotely to the ray worker
    def train_step(self, params, x, y, learning_rate, key):
        grads = jax.grad(self.loss)(params, x, y)
        new_params = jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return new_params, key

    def run(self):
        self.initialize()
        # Generate some example data
        key = random.PRNGKey(0)
        x = random.normal(key, (self.num_samples, input_size))
        y = random.normal(key, (self.num_samples, output_size))
        
        # jit the train function
        train_step_jit = jax.jit(self.train_step)
        
        params = self.params
        # Train the network
        for epoch in range(self.num_epochs):
            if self.worker_id == 0 and epoch % 100 == 0:
                print("training epoch {}".format(epoch))
            key, subkey = random.split(key)
            params, _ = train_step_jit(params, x, y, self.learning_rate, subkey)
            
        if self.worker_id == 0:
            print("training completed")

    def predict(self):
        print("my pid: {}".format(os.getpid()))
        key = random.PRNGKey(0)
        # Use the trained network to make predictions on new data
        key, subkey = random.split(key)
        new_x = random.normal(subkey, (1, input_size))    
        new_y = self.forward(self.params, new_x)
        print(new_y)
            
        if self.worker_id == 0:
            return new_y
        
        return None

In [7]:
num_samples = 100
learning_rate = 0.1
num_epochs = 1000

actor_refs = [JaxNNActor.options(
    runtime_env={
        "env_vars":{
            "TPU_WORKER_ID":str(i), 
            "TPU_WORKER_HOSTNAMES": tpu_worker_list, 
            "TPU_SKIP_MDS_QUERY": "true",
            "JAX_PLATFORMS": "",
            "TF_CPP_MIN_LOG_LEVEL": "0"
        },
    },
    scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
        node_id=node_ids[i],
        soft=False,
    ),
).remote(i, num_samples, learning_rate, num_epochs) for i in range(num_workers)]

In [10]:
refs = []
for i in range(len(actor_refs)):
    refs.append(actor_refs[i].run.remote())
print(refs)
ray.get(refs)

[ClientObjectRef(0ae631196af8e7cfb95f5b49171b4a6877e10a4f0300000001000000)]
[2m[36m(JaxNNActor pid=5947)[0m training epoch 0
[2m[36m(JaxNNActor pid=5947)[0m training epoch 100
[2m[36m(JaxNNActor pid=5947)[0m training epoch 200
[2m[36m(JaxNNActor pid=5947)[0m training epoch 300
[2m[36m(JaxNNActor pid=5947)[0m training epoch 400
[2m[36m(JaxNNActor pid=5947)[0m training epoch 500
[2m[36m(JaxNNActor pid=5947)[0m training epoch 600


[None]

[2m[36m(JaxNNActor pid=5947)[0m training epoch 700
[2m[36m(JaxNNActor pid=5947)[0m training epoch 800
[2m[36m(JaxNNActor pid=5947)[0m training epoch 900
[2m[36m(JaxNNActor pid=5947)[0m training completed


In [11]:
ray.get(actor_refs[0].predict.remote())

Unable to initialize backend 'tpu': ALREADY_EXISTS: PJRT_Api already exists for device type tpu (set JAX_PLATFORMS='' to automatically choose an available backend)


[2m[36m(JaxNNActor pid=5947)[0m my pid: 5947
[2m[36m(JaxNNActor pid=5947)[0m [[-0.37971944 -1.6159217 ]]


RaySystemError: System error: Unable to initialize backend 'tpu': ALREADY_EXISTS: PJRT_Api already exists for device type tpu (set JAX_PLATFORMS='' to automatically choose an available backend)
traceback: Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 435, in backends
    backend = _init_backend(platform)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 488, in _init_backend
    backend = factory()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 189, in tpu_client_timer_callback
    client = xla_client.make_tpu_client()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 172, in make_tpu_client
    load_pjrt_plugin_dynamically('tpu', library_path)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jaxlib/xla_client.py", line 135, in load_pjrt_plugin_dynamically
    _xla.load_pjrt_plugin(plugin_name, library_path)
jaxlib.xla_extension.XlaRuntimeError: ALREADY_EXISTS: PJRT_Api already exists for device type tpu

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/serialization.py", line 369, in deserialize_objects
    obj = self._deserialize_object(data, metadata, object_ref)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/serialization.py", line 252, in _deserialize_object
    return self._deserialize_msgpack_data(data, metadata_fields)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/serialization.py", line 207, in _deserialize_msgpack_data
    python_objects = self._deserialize_pickle5_data(pickle5_data)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/ray/_private/serialization.py", line 197, in _deserialize_pickle5_data
    obj = pickle.loads(in_band)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/array.py", line 106, in _reconstruct_array
    jnp_value = api.device_put(np_value)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/api.py", line 2448, in device_put
    return tree_map(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/tree_util.py", line 210, in tree_map
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/tree_util.py", line 210, in <genexpr>
    return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/api.py", line 2449, in <lambda>
    lambda y: dispatch.device_put_p.bind(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/core.py", line 380, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/core.py", line 383, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/core.py", line 790, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/dispatch.py", line 675, in _device_put_impl
    sh = SingleDeviceSharding(pxla._get_default_device()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 1761, in _get_default_device
    return config.jax_default_device or xb.local_devices()[0]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 631, in local_devices
    process_index = get_backend(backend).process_index()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 533, in get_backend
    return _get_backend_uncached(platform)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 514, in _get_backend_uncached
    bs = backends()
  File "/home/ray/anaconda3/lib/python3.10/site-packages/jax/_src/xla_bridge.py", line 452, in backends
    raise RuntimeError(err_msg)
RuntimeError: Unable to initialize backend 'tpu': ALREADY_EXISTS: PJRT_Api already exists for device type tpu (set JAX_PLATFORMS='' to automatically choose an available backend)


In [None]:
for i in range(num_workers):
    print ("terminating actor {}".format(i))
    ray.kill(actor_refs[i])