In [None]:
# Registry of available expert models, keyed by display name.
# Order matters: the gating network's output index i is mapped back to the
# i-th key of this dict, and its output size is len(experts), so keep the
# insertion order stable between training and inference.
expert_pairs = [
    ("GPT-4", gpt4_model),
    ("Fraud_Detection_AI", fraud_detection_model),
    ("Risk_Assessment_AI", risk_assessment_model),
    ("OCR_Model", ocr_model),
]
experts = dict(expert_pairs)
del expert_pairs

In [None]:
# Build the Gating Network (Neural Network Classifier)
# A small MLP (multi-layer perceptron) whose softmax output assigns a probability to each expert; the highest-probability experts are selected.

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    """Two-layer MLP that scores each expert for a given feature vector.

    Args:
        input_dim: Number of input features per request.
        num_experts: Number of experts to score (size of the output layer).
        hidden_dim: Width of the hidden layer. Defaults to 64, the original
            hard-coded size.

    ``forward`` returns softmax probabilities over experts by default. Pass
    ``return_logits=True`` to get the raw pre-softmax scores instead. Those
    are what ``nn.CrossEntropyLoss`` expects, because it applies log-softmax
    itself; feeding it the probabilities would softmax twice.
    """

    def __init__(self, input_dim, num_experts, hidden_dim=64):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)    # hidden layer
        self.fc2 = nn.Linear(hidden_dim, num_experts)  # one score per expert

    def forward(self, x, return_logits=False):
        """Score every expert for ``x``.

        Args:
            x: Float tensor whose last dimension has size ``input_dim``. Any
                leading (batch) dimensions are preserved by ``nn.Linear``.
            return_logits: If True, return raw scores instead of probabilities.

        Returns:
            Tensor with the same leading dims as ``x`` and a last dimension of
            size ``num_experts``. By default these are probabilities that sum
            to 1 along the last dimension.
        """
        logits = self.fc2(F.relu(self.fc1(x)))
        if return_logits:
            return logits
        return F.softmax(logits, dim=-1)  # probability distribution over experts


## Training the Gating Network

In [None]:
# Example: synthetic training data. Features and target experts are both random,
# so there is no real signal to learn; this only demonstrates the mechanics.
torch.manual_seed(0)  # reproducible data, weight init and training run

NUM_SAMPLES = 1000
NUM_FEATURES = 10
NUM_EPOCHS = 100
LEARNING_RATE = 0.01

X_train = torch.randn(NUM_SAMPLES, NUM_FEATURES)          # (samples, features)
y_train = torch.randint(0, len(experts), (NUM_SAMPLES,))  # target expert indices

# Define model, loss, and optimizer
gating_model = GatingNetwork(input_dim=NUM_FEATURES, num_experts=len(experts))
# The model's forward() already applies softmax. nn.CrossEntropyLoss expects raw
# logits and would apply log-softmax a second time, distorting loss and gradients.
# Instead take log(probabilities) and use NLLLoss, which together equal cross
# entropy. The clamp keeps log() finite if a probability underflows to 0.
criterion = nn.NLLLoss()
optimizer = torch.optim.Adam(gating_model.parameters(), lr=LEARNING_RATE)

# Training loop (full-batch: every step uses all NUM_SAMPLES rows)
for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    outputs = gating_model(X_train)
    loss = criterion(torch.log(outputs.clamp_min(1e-12)), y_train)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        print(f"Epoch {epoch}: Loss = {loss.item()}")


## Using the Gating Network for Model Selection

In [None]:
def select_experts(input_features, k=2, model=None, expert_names=None):
    """Return the names of the top-``k`` experts chosen by the gating network.

    Args:
        input_features: Features for one request, as a 1-D vector of length
            ``n_features`` or a ``(1, n_features)`` row. It may also be a
            ``(batch, n_features)`` batch. Any tensor or array-like accepted
            by ``torch.as_tensor``.
        k: Number of experts to select per request. It is capped at the
            number of experts.
        model: Callable mapping a ``(batch, n_features)`` float tensor to
            ``(batch, n_experts)`` scores. Defaults to the global
            ``gating_model``.
        expert_names: Expert names in gating-output order. Defaults to the
            keys of the global ``experts`` dict.

    Returns:
        For a single request, a list of ``k`` expert names ordered from
        highest to lowest score. For a batch with more than one row, a list
        of such lists, one per row.
    """
    if model is None:
        model = gating_model
    expert_names = list(experts.keys()) if expert_names is None else list(expert_names)

    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensors.
    features = torch.as_tensor(input_features, dtype=torch.float32)
    single_request = features.dim() == 1 or features.shape[0] == 1
    # Always score a 2-D (batch, n_features) input so topk indices are (batch, k).
    # Indexing a Python list with a multi-element index tensor raises TypeError.
    features = features.reshape(-1, features.shape[-1])

    with torch.no_grad():  # inference only; no autograd graph needed
        gating_output = model(features)

    k = min(k, gating_output.shape[-1])  # topk raises if k exceeds the expert count
    top_indices = torch.topk(gating_output, k=k, dim=-1).indices
    selections = [[expert_names[i] for i in row.tolist()] for row in top_indices]
    return selections[0] if single_request else selections

# Example request: one simulated underwriting request, passed as a single 1-D
# feature vector of length 10 (the gating network's input_dim).
input_features = torch.randn(10)
chosen_models = select_experts(input_features)
print("Selected Models:", chosen_models)
