In [11]:
from expert import FeedForwardExpert
import torch
import torch.nn as nn
import torch.nn.functional as F
class FeedForward(nn.Module):
    def __init__(self, hidden_dim, num_of_experts, top_k) -> None:
        super().__init__()
        self.top_k = top_k
        self.experts = nn.ModuleList([FeedForwardExpert(hidden_dim) for i in range(num_of_experts)])
        self.gate = nn.Linear(hidden_dim, num_of_experts)


    def forward(self, x):
        batch_size = x.shape[0]
        seq_len = x.shape[1]

        gate_output = self.gate(x)

        # get probabilities for each expert
        gate_output = F.softmax(gate_output, dim=-1)
        
        # get top k experts
        top_k_experts, expert_indices = torch.topk(gate_output, self.top_k, dim=-1)

        # re-normalize probabilities for top k experts
        top_k_experts_weights = top_k_experts / torch.sum(top_k_experts, dim=-1, keepdim=True)

        # place holder for output
        expert_outputs = torch.zeros_like(x)

        for batch in range(batch_size):
            for tok_pos in range(seq_len):
                for k in range(self.top_k):
                    expert_index = expert_indices[batch, tok_pos, k].item()
                    curent_expert_output = self.experts[expert_index](x[batch, tok_pos])
                    expert_outputs[batch, tok_pos, k] = curent_expert_output * top_k_experts_weights[batch, tok_pos, k]
        

        return expert_outputs


In [12]:
model = FeedForward(512, 8, 2)

# set up dummy input
# batch_size, seq_len, hidden_dim
x = torch.randn(1, 5, 512)

model(x)


torch.Size([512])


RuntimeError: expand(torch.FloatTensor{[512]}, size=[]): the number of sizes provided (0) must be greater or equal to the number of dimensions in the tensor (1)