# In [33]:
import torch
import torch.nn as nn

class BayesFactorLayer(nn.Module):
    """Propagate a (mean, variance) pair through a layer of summed log Bayes factors.

    Every (output, input) connection owns two (mean, variance) parameter
    pairs: ``(mu_1, var_1)`` and ``(mu_2, var_2)``. ``forward`` computes an
    approximate mean and variance of a per-connection log Bayes factor, sums
    them over the input dimension and adds a learnable bias to the mean.

    NOTE(review): ``var_1`` and ``var_2`` are plain unconstrained parameters;
    nothing here keeps them positive during training -- confirm that this is
    intended.
    """

    def __init__(self, num_input_neurons, num_output_neurons):
        """Create the layer parameters.

        Args:
            num_input_neurons: Size of the input feature dimension.
            num_output_neurons: Number of output neurons of the layer.
        """
        super().__init__()
        self.num_input_neurons = num_input_neurons
        self.num_output_neurons = num_output_neurons

        # Per-connection parameters, one entry per (output, input) pair.
        self.mu_1 = nn.Parameter(torch.randn(num_output_neurons, num_input_neurons))
        self.mu_2 = nn.Parameter(torch.randn(num_output_neurons, num_input_neurons))
        self.var_1 = nn.Parameter(torch.ones(num_output_neurons, num_input_neurons))
        self.var_2 = nn.Parameter(torch.ones(num_output_neurons, num_input_neurons))

        # Bias added to the summed mean of each output neuron.
        self.bias_mean = nn.Parameter(torch.zeros(1, num_output_neurons))

    def forward(self, mu_x, var_x):
        """Compute the mean and variance of the summed log Bayes factors.

        All exponential terms are expressed through ``sigmoid`` and
        ``softplus`` instead of raw ``exp``. The quantities are algebraically
        the same, but they stay finite when ``mu + mu_x`` is large, where
        ``exp`` overflows and ``inf / inf`` would turn into NaN.

        Args:
            mu_x: Input means; expected shape ``(batch, num_input_neurons)``
                (a dimension is inserted at index 1 to broadcast against the
                ``(num_output_neurons, num_input_neurons)`` parameters).
            var_x: Input variances, same shape as ``mu_x``.

        Returns:
            Tuple ``(mu_out_sum, var_out_sum)``; for 2-D inputs both have
            shape ``(batch, num_output_neurons)``.
        """
        softplus = nn.functional.softplus

        var_x_expanded = var_x.unsqueeze(1)
        mu_x_expanded = mu_x.unsqueeze(1)

        # Means shifted by the input mean: exp(mu_1x) == exp(mu_1) * exp(mu_x).
        mu_1x = self.mu_1 + mu_x_expanded
        mu_2x = self.mu_2 + mu_x_expanded

        v1x, v2x = self.var_1 + var_x_expanded, self.var_2 + var_x_expanded
        s41, s42 = torch.square(self.var_1), torch.square(self.var_2)

        # exp(a) / (exp(a) + 1) == sigmoid(a)
        e1op1, e2op1 = torch.sigmoid(self.mu_1), torch.sigmoid(self.mu_2)
        e1xop1, e2xop1 = torch.sigmoid(mu_1x), torch.sigmoid(mu_2x)

        # exp(a) / (exp(a) + 1)**2 == sigmoid(a) * sigmoid(-a); sigmoid(-a)
        # is used instead of 1 - sigmoid(a) to avoid cancellation for large a.
        e1op12 = e1op1 * torch.sigmoid(-self.mu_1)
        e2op12 = e2op1 * torch.sigmoid(-self.mu_2)
        e1xop12 = e1xop1 * torch.sigmoid(-mu_1x)
        e2xop12 = e2xop1 * torch.sigmoid(-mu_2x)

        # Mean of log Bayes factor.
        # log((e^mu_2 + 1)(e^mu_1x + 1) / ((e^mu_1 + 1)(e^mu_2x + 1))) written
        # with softplus(a) == log(exp(a) + 1).
        mu_out = (softplus(mu_1x) - softplus(mu_2x)) + (
            softplus(self.mu_2) - softplus(self.mu_1)
        )
        mu_out = mu_out + (self.var_2 * e2op12 + v1x * e1xop12) / 2
        mu_out = mu_out - (self.var_1 * e1op12 + v2x * e2xop12) / 2

        # Variance of log Bayes factor.
        var_out = (
            self.var_1 * torch.square(e1op1)
            + self.var_2 * torch.square(e2op1)
            + v1x * torch.square(e1xop1)
            + v2x * torch.square(e2xop1)
        )
        var_out = var_out + (s41 * torch.square(e1op12) + s42 * torch.square(e2op12)) / 2
        var_out = var_out + (torch.square(v1x * e1xop12) + torch.square(v2x * e2xop12)) / 2
        var_out = var_out - 2 * (
            self.var_1 * e1xop1 * e1op1
            + self.var_2 * e2xop1 * e2op1
            + var_x_expanded * e1xop1 * e2xop1
        )
        var_out = var_out - (
            s41 * e1xop12 * e1op12
            + s42 * e2xop12 * e2op12
            + torch.square(var_x_expanded) * e1xop12 * e2xop12
        )

        # Sum the log Bayes factors over the inputs and add the bias.
        mu_out_sum = mu_out.sum(dim=-1) + self.bias_mean
        var_out_sum = var_out.sum(dim=-1)

        return mu_out_sum, var_out_sum

class BayesFactorNetwork(nn.Module):
    """Stack of ``BayesFactorLayer`` modules with an extra output bias.

    The (mean, variance) pair is passed through every hidden layer in order,
    then through the output layer; finally a network-level bias is added to
    both the output mean and the output variance.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        """Build the hidden layers, the output layer and the output biases.

        Args:
            input_size: Number of input neurons.
            hidden_sizes: Iterable with the size of each hidden layer.
            output_size: Number of output neurons.
        """
        super().__init__()

        # Consecutive widths give the (in, out) sizes of each hidden layer.
        widths = [input_size, *hidden_sizes]
        self.layers = nn.ModuleList(
            [
                BayesFactorLayer(fan_in, fan_out)
                for fan_in, fan_out in zip(widths, widths[1:])
            ]
        )

        self.bias_mean = nn.Parameter(torch.zeros(1, output_size))
        self.bias_var = nn.Parameter(torch.ones(1, output_size))

        self.output_layer = BayesFactorLayer(widths[-1], output_size)

    def forward(self, mu_x, var_x):
        """Run the (mean, variance) pair through all layers.

        Args:
            mu_x: Means of the input neurons.
            var_x: Variances of the input neurons.

        Returns:
            Tuple ``(mu_out, var_out)`` of the output layer results with
            ``bias_mean`` and ``bias_var`` added respectively.
        """
        mean, variance = mu_x, var_x
        for hidden in self.layers:
            mean, variance = hidden(mean, variance)

        mean, variance = self.output_layer(mean, variance)
        return mean + self.bias_mean, variance + self.bias_var

# Example usage: build a small network and run one forward pass.
# Sizes: input neurons, hidden layer widths, output neurons.
input_size, hidden_sizes, output_size = 10, [20, 30], 5
batch_size = 2

# Create the network
model = BayesFactorNetwork(input_size, hidden_sizes, output_size)

# Example input: random means and unit variances for every input neuron
mu_x = torch.randn(batch_size, input_size)
var_x = torch.ones_like(mu_x)

# Forward pass
mu_out, var_out = model(mu_x, var_x)
print("Output means:", mu_out)
print("Output variances:", var_out)

# Output means: tensor([[ 0.9186,  0.9832, -1.5293, -0.3186, -0.6826],
#         [ 0.9125,  0.6633, -1.8544, -0.8576, -0.7011]], grad_fn=<AddBackward0>)
# Output variances: tensor([[43.2472, 48.7017, 46.1232, 42.9346, 49.7729],
#         [42.1023, 47.7009, 43.7699, 43.6115, 50.1957]], grad_fn=<AddBackward0>)
