import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_map
import numpy as np
import socket

hostname = socket.gethostname()
world = mx.distributed.init()
mx.set_default_device(mx.gpu)


class MLP(nn.Module):
    def __init__(self, in_dims: int, out_dims: int):
        super().__init__()
        self.layers = [
            nn.Linear(in_dims, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, 128),
            nn.Linear(128, out_dims),
        ]

    def __call__(self, x):
        for i, l in enumerate(self.layers):
            x = mx.maximum(x, 0) if i > 0 else x
            x = l(x)
        return x


# Generate a synthetic dataset with a quadratic relationship
# (cast to float32, since MLX computes in single precision)
def generate_dataset(num_samples, in_dims, out_dims):
    X = np.linspace(-1, 1, num_samples * in_dims).reshape(num_samples, in_dims).astype(np.float32)
    Y = (X**2 + 0.001 * np.random.randn(num_samples, out_dims)).astype(np.float32)
    return mx.array(X), mx.array(Y)


# Dataset and model dimensions (1000 samples in total, split across processes)
total_num_samples = 1000
in_dims = 200000
out_dims = 200000

# Initialize distributed settings
num_processes = world.size()
rank = world.rank()

# Number of samples handled by each process
num_samples_per_process = total_num_samples // num_processes

# Each process generates its own portion of the dataset
X, Y = generate_dataset(num_samples_per_process, in_dims, out_dims)

# The model is created with all its parameters, but nothing is initialized yet
# because MLX evaluates lazily
model = MLP(in_dims, out_dims)

print(f"Distributed available: {mx.distributed.is_available()}")
print(f"Hostname: {hostname}, rank: {rank}")
print(f"Number of processes: {num_processes}")
print(f"Number of samples per process: {num_samples_per_process}")

# Force-evaluate all parameters to initialize the model
mx.eval(model.parameters())

optimizer = optim.Adam(learning_rate=0.0001)


# A simple L2 loss function.
def l2_loss(model, x, y):
    y_hat = model(x)
    return mx.mean(mx.square(y_hat - y))


# Average gradients across processes with an all-reduce (a no-op for a single process)
def all_reduce_grads(grads, N):
    if N == 1:
        return grads
    return tree_map(lambda g: mx.distributed.all_sum(g) / N, grads)


def step(model, x, y):
    loss_and_grad_fn = nn.value_and_grad(model, l2_loss)
    loss, grads = loss_and_grad_fn(model, x, y)
    return loss, grads


# Attempt a barrier using all_gather: every rank contributes its rank id and
# loops until the gathered result contains all ranks
def barrier():
    while True:
        sync_tensor = mx.array([rank], dtype=mx.int32)
        gathered = mx.distributed.all_gather(sync_tensor)
        # Evaluating the comparison forces the collective to run on every rank
        if gathered.size == num_processes and mx.all(
            gathered == mx.arange(num_processes, dtype=mx.int32)
        ).item():
            break


# Training loop with mini-batch processing
batch_size = 10
num_batches = num_samples_per_process // batch_size

for epoch in range(20):  # Number of epochs
    epoch_loss = mx.array([0.0])
    for i in range(num_batches):
        x_batch = X[i * batch_size:(i + 1) * batch_size]
        y_batch = Y[i * batch_size:(i + 1) * batch_size]

        # Accumulate gradients over the mini-batch, one sample at a time
        batch_loss = mx.array([0.0])
        batch_grads = None
        for j in range(batch_size):
            x = x_batch[j]
            y = y_batch[j]
            loss, grads = step(model, x, y)
            batch_loss += loss
            if batch_grads is None:
                batch_grads = grads
            else:
                batch_grads = tree_map(lambda g1, g2: g1 + g2, batch_grads, grads)

        # Average the gradients over the batch
        batch_grads = tree_map(lambda g: g / batch_size, batch_grads)

        # All-reduce to average gradients across all processes
        batch_grads = all_reduce_grads(batch_grads, num_processes)

        # Update the model with the averaged gradients and force evaluation so the
        # lazy graph does not keep growing across iterations
        optimizer.update(model, batch_grads)
        mx.eval(model.parameters(), optimizer.state)

        epoch_loss += batch_loss

    print(f"Rank: {rank} about to print")
    # all_sum is a collective op, so every rank must build and evaluate it;
    # running it only on rank 0 (e.g. inside the print below) hangs the other ranks
    avg_loss = mx.distributed.all_sum(epoch_loss) / num_batches
    mx.eval(avg_loss)
    if rank == 0:
        print(f"Epoch {epoch + 1}, Loss: {avg_loss}")
    barrier()

barrier()
print(f"Rank: {rank} finished")
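
For reference, a common MLX pattern is to compute the loss and gradients over the whole mini-batch in a single call and to force evaluation after each update, rather than looping over samples and summing gradients by hand. The sketch below reuses the names defined above (model, l2_loss, all_reduce_grads, optimizer, num_processes); it is a minimal illustration under those assumptions, not a drop-in replacement for the loop above.

# Hypothetical batched training step: nn.value_and_grad already differentiates
# through a whole mini-batch, so no per-sample Python loop is needed.
loss_and_grad_fn = nn.value_and_grad(model, l2_loss)

def train_step(x_batch, y_batch):
    loss, grads = loss_and_grad_fn(model, x_batch, y_batch)
    # Average gradients across ranks before applying the update
    grads = all_reduce_grads(grads, num_processes)
    optimizer.update(model, grads)
    # Force evaluation so the lazy graph (including the collective) actually
    # runs on every rank at every step
    mx.eval(model.parameters(), optimizer.state, loss)
    return loss

Assuming the MPI backend, a script like this is normally launched on all ranks at once, e.g. with something like "mpirun -np 2 python train.py" (the script name here is just a placeholder).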