## Task 3: Mixture of Experts with Ray

This assignment focuses on implementing a distributed Mixture of Experts (MoE) model using Ray. To help you understand the workflow, we provide a reference implementation called `SimpleMoE`. You will then complete the skeleton for the `MoE_TP` class, which incorporates tensor parallelism (TP) into MoE.

In [None]:
import numpy as np
import ray
from rng import get_rng, rng_context, reset_rngs
import time
ray.init(ignore_reinit_error=True, num_cpus=8)

# For simplicity, we assume that `hidden_dim` and `output_dim`
# are evenly divisible by `num_workers`.
params = {
    "batch_size": 1000,
    "feature_dim": 1000,
    "hidden_dim": 1000,
    "output_dim": 1000,
    "num_experts": 10,
    "topk": 2,
}
num_workers = 10

## SimpleMoE

Mixture of Experts (MoE) is a neural network architecture where multiple “experts” (sub-networks) exist in parallel. A gating network (the router) selects which experts process each input, and `topk` specifies how many experts are chosen per input. The outputs of the selected experts are combined as a weighted sum, using the router's normalized gate values as weights. This approach can greatly increase model capacity without a proportional increase in computation per sample.
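
To make the routing step concrete, here is a minimal NumPy sketch of top-k gating followed by the gated combination of expert outputs. It is an illustration only: the function name `route_topk`, the logits, and the toy "experts" (which simply scale their input) are all hypothetical and are not part of the assignment code.

```python
import numpy as np

def route_topk(logits, topk):
    """Return the top-k expert indices per sample and their renormalized gates."""
    # Numerically stable softmax over the expert axis
    exp_logits = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = exp_logits / exp_logits.sum(axis=1, keepdims=True)
    # Indices of the k most probable experts, most probable first
    indices = np.argsort(-probs, axis=1)[:, :topk]
    gates = np.take_along_axis(probs, indices, axis=1)
    # Renormalize so the selected gates sum to 1 for each sample
    gates = gates / gates.sum(axis=1, keepdims=True)
    return indices, gates

# Hypothetical setup: 2 samples, 4 toy experts that just scale the input
x = np.array([[1.0, 2.0],
              [3.0, 4.0]])
logits = np.array([[2.0, 0.5, 1.0, -1.0],
                   [0.0, 3.0, 0.1, 2.5]])
experts = [lambda v, s=s: s * v for s in (1.0, 10.0, 100.0, 1000.0)]

indices, gates = route_topk(logits, topk=2)

# Each output row is the gate-weighted sum of its selected experts' outputs
y = np.zeros_like(x)
for i in range(x.shape[0]):
    for k in range(indices.shape[1]):
        y[i] += gates[i, k] * experts[indices[i, k]](x[i])

print(indices)            # which experts each sample was routed to
print(gates.sum(axis=1))  # gates sum to 1 per sample
```

Only `topk` of the experts run for each sample, which is why capacity can grow with the number of experts while per-sample compute stays roughly constant.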

In [None]:
class Linear:
    """Simple linear layer y = xW + b"""

    def __init__(self, in_features, out_features):
        self.weight = get_rng().randn(in_features, out_features) * 0.01
        self.bias = get_rng().randn(out_features)

    def __call__(self, x):
        return np.dot(x, self.weight) + self.bias


class Expert:
    """Expert network with one hidden layer and ReLU activation"""

    def __init__(self, input_dim, hidden_dim, output_dim):
        self.fc1 = Linear(input_dim, hidden_dim)
        self.fc2 = Linear(hidden_dim, output_dim)

    def __call__(self, x):
        hidden = self.fc1(x)
        hidden = np.maximum(0, hidden)  # ReLU
        return self.fc2(hidden)


class Router:
    """Routes inputs to experts using softmax-based gating"""

    def __init__(self, input_dim, num_experts):
        self.linear = Linear(input_dim, num_experts)

    def __call__(self, x, topk=1):
        logits = self.linear(x)

        # Softmax for routing probabilities
        exp_logits = np.exp(logits - np.max(logits, axis=1, keepdims=True))
        probs = exp_logits / np.sum(exp_logits, axis=1, keepdims=True)

        # Select top-k experts
        indices = np.argsort(-probs, axis=1)[:, :topk]
        gates = np.take_along_axis(probs, indices, axis=1)

        # Normalize gates to sum to 1
        gates = gates / np.sum(gates, axis=1, keepdims=True)

        return indices, gates


class SimpleMoE:
    """
    Simple reference implementation of Mixture of Experts.
    
    Args:
        input_dim (int): Input feature dimension
        hidden_dim (int): Hidden dimension for each expert
        output_dim (int): Output dimension
        num_experts (int): Number of expert networks
        topk (int): Number of experts to route each input to
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, topk=1):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.topk = min(topk, num_experts)

        with rng_context('router'):
            self.router = Router(input_dim, num_experts)

        with rng_context('expert'):
            self.experts = [Expert(input_dim, hidden_dim, output_dim)
                            for _ in range(num_experts)]

    def forward(self, x):
        batch_size = x.shape[0]
        indices, gates = self.router(x, self.topk)
        outputs = np.zeros((batch_size, self.output_dim))

        for k in range(self.topk):
            for i in range(batch_size):
                expert_idx = indices[i, k]
                gate = gates[i, k]
                item = x[i:i + 1]
                expert_output = self.experts[expert_idx](item)
                outputs[i] += gate * expert_output[0]

        return outputs

    def __call__(self, x):
        return self.forward(x)


In [None]:
def test_simple_moe(batch_size=10, feature_dim=10, hidden_dim=10, output_dim=10, num_experts=1, topk=1):
    """Test SimpleMoE for correctness and performance"""
    reset_rngs()
    
    # Generate input data
    with rng_context("testing"):
        X = get_rng().randn(batch_size, feature_dim)
    
    moe = SimpleMoE(
        input_dim=feature_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_experts=num_experts,
        topk=topk
    )
    
    # Warm up to better measure efficiency
    _ = moe(X)
    
    # Measure time (single run after warm-up)
    start_time = time.perf_counter()
    output = moe(X)
    end_time = time.perf_counter()
    duration_s = end_time - start_time  # elapsed wall-clock time in seconds
    
    return output, duration_s

result_simple, time_simple = test_simple_moe(**params)
print("Simple MoE:")
print(f"  Time: {time_simple * 1000:.2f} ms")

## Task 3.1: MoE with Tensor Parallelism

**Tensor Parallelism (TP)** is a distributed training technique where a single layer (e.g., a large matrix multiplication) is split across multiple devices. Each device computes a partial result, and the partial results are combined to produce the full output.

In **MoE with Tensor Parallelism (MoE-TP)**, these ideas are combined:

- Each expert consists of multiple layers, and the parameters of every layer are sharded across all workers (i.e., each worker stores only a slice of the layer’s weights).  
  In this assignment, we assume there are 10 experts, each containing 2 fully connected layers, for a total of 20 layers. Under tensor parallelism, each worker maintains a shard from every one of these 20 layers.

- Every worker computes its partial output for each layer (based on the parameters it stores).

- The partial outputs are then concatenated along the output-feature axis to obtain the full layer output. See the `ShardedLinear` class for how a sharded linear layer works.

- Because every worker participates in every expert's computation, the workload stays balanced across workers regardless of how the router distributes inputs among experts.
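
The sharding scheme in the bullets above can be checked without Ray. The following is a minimal single-process sketch, assuming the weight matrix is split by output columns as in `ShardedLinear`; the sizes and the local `rng` generator are illustrative and unrelated to the assignment's `rng` module.

```python
import numpy as np

rng = np.random.default_rng(0)  # local generator, for illustration only
batch, in_features, out_features, world_size = 4, 6, 8, 4

x = rng.standard_normal((batch, in_features))
W = rng.standard_normal((in_features, out_features))
b = rng.standard_normal(out_features)

# Each simulated worker owns a contiguous block of output columns
local_out = out_features // world_size
partials = []
for rank in range(world_size):
    cols = slice(rank * local_out, (rank + 1) * local_out)
    # Local computation: full input times this worker's weight columns
    partials.append(x @ W[:, cols] + b[cols])

# Concatenating the partial outputs in rank order recovers the full layer output
sharded = np.concatenate(partials, axis=1)
full = x @ W + b
print(np.allclose(sharded, full))  # True
```

Every worker needs the full input `x` but only `1 / world_size` of the weights, and the partial outputs must be concatenated in rank order for the columns to line up.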

In [None]:
@ray.remote
class LinearShardWorker:
    """
    Generic Ray worker that holds weight shards for all experts.
    Each fc layer from each expert is represented by a unique layer_id.
    """
    
    def __init__(self, rank, world_size):
        self.rank = rank
        self.world_size = world_size
        self.layers = {}  # Store weights for multiple layers by layer_id
    
    def initialize_layer(self, layer_id, weight_shard, bias_shard):
        """
        Store weight shards for a specific layer.
        
        Args:
            layer_id: Unique identifier for this layer
            weight_shard: Pre-computed weight shard for this worker
            bias_shard: Pre-computed bias shard for this worker
        """
        local_out_features = weight_shard.shape[1]
        self.layers[layer_id] = {'weight': weight_shard, 'bias': bias_shard, 'local_out_features': local_out_features}
        return local_out_features
    
    def forward_layer(self, layer_id, x):
        """
        Compute local shard of a layer's output.
        
        Args:
            layer_id: Identifier for which layer to use
            x: Input of shape (batch_size, in_features)
            
        Returns:
            Local output shard of shape (batch_size, out_features // world_size)
        """
        layer = self.layers[layer_id]
        if x.shape[0] == 0:
            return np.zeros((0, layer['local_out_features']))
        
        # Perform local computation
        local_output = np.dot(x, layer['weight']) + layer['bias']
        return local_output


class ShardedLinear:
    """
    Linear layer that is sharded across Ray workers.
    
    Each worker holds one shard; __call__ dispatches the input to all workers
    and concatenates their partial outputs.
    """

    # _layer_counter: Class-level counter for unique ShardedLinear IDs
    # (In MoE_TP, you have multiple experts (e.g., 10 experts), each expert has 2 layers (fc1 and fc2),
    # and all layers share the same workers. So the same Ray worker holds shards for:
    # - Expert 0's fc1
    # - Expert 0's fc2
    # - Expert 1's fc1
    # - Expert 1's fc2
    # - ... 20 layers total for 10 experts!)
    _layer_counter = 0
    
    def __init__(self, in_features, out_features, workers):
        self.workers = workers
        self.world_size = len(workers)
        
        # Assert that out_features is evenly divisible by world_size
        assert out_features % self.world_size == 0, \
            f"Output features ({out_features}) must be evenly divisible by world size ({self.world_size})"
        
        # Calculate the local output dimension
        self.out_features_global = out_features
        self.local_out_features = out_features // self.world_size
        
        # Assign unique layer ID
        self.layer_id = f"layer_{ShardedLinear._layer_counter}"
        ShardedLinear._layer_counter += 1
        
        # Generate the full weights first, then shard them.
        # Drawing from the random number generator (rng) in the same order as
        # Linear ensures the shards reproduce the weights used in SimpleMoE.
        full_weight = get_rng().randn(in_features, out_features) * 0.01
        full_bias = get_rng().randn(out_features)
        
        # Distribute shards to workers
        futures = []
        for rank, worker in enumerate(self.workers):
            offset = rank * self.local_out_features

            ### Compute weight_shard and bias_shard, then call worker.initialize_layer.remote
            weight_shard = full_weight[:, offset:offset + self.local_out_features]
            bias_shard = full_bias[offset:offset + self.local_out_features]
            futures.append(
                worker.initialize_layer.remote(self.layer_id, weight_shard, bias_shard)
            )
        ray.get(futures)  # Wait for initialization
    
    def __call__(self, x):
        """
        Forward pass through sharded linear layer.
        """
        # Handle empty batch case
        if x.shape[0] == 0:
            return np.zeros((0, self.out_features_global))
        
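        # Dispatch x to every worker's shard of this layer and collect the returned
        # object refs in a list named `futures` (one per worker, in rank order)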
        # YOUR CODE HERE
        raise NotImplementedError()

        shards = ray.get(futures)
        
        # Concatenate shards
        result = np.concatenate(shards, axis=1)
        
        return result


class ShardedExpert:
    """
    Expert network with one hidden layer and ReLU activation, sharded across workers.
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim, workers):
        # Initialize two ShardedLinear layers, self.fc1 and self.fc2
        # (create fc1 before fc2 so the rng draws match the order used in Expert)
        # YOUR CODE HERE
        raise NotImplementedError()
    
    def __call__(self, x):
        hidden = self.fc1(x)
        hidden = np.maximum(0, hidden)  # ReLU
        return self.fc2(hidden)

In [None]:
class MoE_TP:
    """
    Distributed Mixture of Experts using Ray for tensor parallelism.
    
    TP-style MoE:
    - Each worker holds a portion of every expert (sharded experts)
    - Router is replicated
    - ShardedExpert classes coordinate worker communication
    
    Args:
        input_dim (int): Input feature dimension
        hidden_dim (int): Hidden dimension for each expert
        output_dim (int): Output dimension
        num_experts (int): Total number of experts in the model
        num_workers (int): Number of parallel workers
        topk (int): Number of experts to route each input to
    """
    
    def __init__(self, input_dim, hidden_dim, output_dim, num_experts, num_workers=4, topk=1):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_experts = num_experts
        self.topk = min(topk, num_experts)
        self.num_workers = num_workers
        
        # Validate dimensions
        assert hidden_dim % num_workers == 0, \
            f"hidden_dim ({hidden_dim}) must be divisible by num_workers ({num_workers})"
        assert output_dim % num_workers == 0, \
            f"output_dim ({output_dim}) must be divisible by num_workers ({num_workers})"
        
        # Initialize Ray if not already initialized
        if not ray.is_initialized():
            ray.init(ignore_reinit_error=True)
        
        # Create router
        with rng_context('router'):
            self.router = Router(input_dim, num_experts)
        
        # Each worker will hold shards of all experts
        self.workers = [
            LinearShardWorker.remote(rank, num_workers)
            for rank in range(num_workers)
        ]

        # Create sharded experts - each expert is sharded across all workers
        with rng_context('expert'):
            self.experts = [
                ShardedExpert(input_dim, hidden_dim, output_dim, self.workers)
                for _ in range(num_experts)
            ]
        
        print(f"Initialized MoE_TP with {num_experts} experts, each sharded across {num_workers} workers")

    def forward(self, x):
        """
        Distributed forward pass through the MoE model using tensor parallelism.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            
        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        batch_size = x.shape[0]
        
        # Compute routing on the driver; the router is small and is not sharded across workers
        indices, gates = self.router(x, self.topk)
        
        # Initialize output tensor
        outputs = np.zeros((batch_size, self.output_dim))

        # Process one expert at a time
        for expert_idx in range(self.num_experts):
            # Collect the batch indices of samples whose top-k choices include this expert,
            # along with the corresponding gate values
            batch_indices = []
            expert_gates = []
            # YOUR CODE HERE
            raise NotImplementedError()
            
            # If no samples are routed to this expert, skip it
            if not batch_indices:
                continue
                
            # Create a batch of inputs for this expert
            expert_inputs = x[batch_indices]
            expert_gates = np.array(expert_gates)[:, np.newaxis]  # Shape: (num_samples, 1)
            
            # Process expert_inputs through this expert
            # Then scale outputs by their gates.
            # Add the gated outputs to the result tensor (outputs) based on batch_indices
            # YOUR CODE HERE
            raise NotImplementedError()
            
        return outputs

    def __call__(self, x):
        return self.forward(x)

    def shutdown(self):
        """Cleanup Ray actors"""
        for worker in self.workers:
            ray.kill(worker)


In [None]:
def test_tp_moe(batch_size=10, feature_dim=10, hidden_dim=10, output_dim=10, num_experts=1, num_workers=1, topk=1):
    """Test MoE_TP for correctness and performance"""
    reset_rngs()
    
    # Generate input data
    with rng_context("testing"):
        X = get_rng().randn(batch_size, feature_dim)
    
    moe = MoE_TP(
        input_dim=feature_dim,
        hidden_dim=hidden_dim,
        output_dim=output_dim,
        num_experts=num_experts,
        num_workers=num_workers,
        topk=topk
    )
    
    # Warm up to better measure efficiency
    _ = moe(X)
    
    # Measure time (single run after warm-up)
    start_time = time.perf_counter()
    output = moe(X)
    end_time = time.perf_counter()
    duration_s = end_time - start_time  # elapsed wall-clock time in seconds
    
    # Cleanup
    if hasattr(moe, 'shutdown'):
        moe.shutdown()
    
    return output, duration_s


result_tp, time_tp = test_tp_moe(**params, num_workers=num_workers)
print("\nTP MoE:")
print(f"  Time: {time_tp * 1000:.2f} ms")
speedup = time_simple / time_tp
assert speedup > 2, f"Tensor Parallel MoE speedup is only {speedup:.2f}x; expected more than 2x over Simple MoE."
print(f"  ✅ Efficiency test passed (speedup {speedup:.2f}x > 2x)")

In [None]:
assert abs(result_tp.sum() - result_simple.sum()) <= 0.1, f"Tensor Parallel MoE test failed: {abs(result_tp.sum() - result_simple.sum())}"
print(f"  ✅ Correctness test passed")