In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
class CustomNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CustomNet, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        pi = self.softmax(x)  # Ensures sum over m of pi_{im}(x) equals 1 and each pi_{im}(x) is positive
        return pi

class CustomLikelihoodLoss(nn.Module):
    def __init__(self, L_im):
        super(CustomLikelihoodLoss, self).__init__()
        self.L_im = torch.tensor(L_im, dtype=torch.float32)
    
    def forward(self, pi, x):
        # pi has shape [batch_size, M]
        # L_im has shape [batch_size, M]
        
        # Compute the sum over m
        inner_sum = torch.sum(pi * self.L_im, dim=1)
        
        # Compute log and sum over i (batch)
        log_likelihood = torch.sum(torch.log(inner_sum))
        
        # Return negative log-likelihood as we want to minimize
        return -log_likelihood


In [3]:
input_size = 10
hidden_size = 20
output_size = 5  # This is M in your equation
batch_size = 32
np.random.seed(42)  
# Generate dummy data
X = torch.randn(batch_size, input_size)
L_im = np.random.rand(batch_size, output_size)  # This would be your actual L_im values


# Create model, loss function, and optimizer
model = CustomNet(input_size, hidden_size, output_size)
criterion = CustomLikelihoodLoss(L_im)
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [4]:
print(X)

tensor([[ 7.8172e-01,  5.3189e-02, -2.1003e-01, -1.1114e+00,  3.1042e-01,
          1.3549e-01, -2.1518e+00,  5.7190e-01, -1.0991e+00, -1.3666e+00],
        [ 6.0167e-01,  8.6843e-02,  1.5716e+00, -1.6280e+00,  3.4600e-01,
          1.3615e-02, -1.3548e+00, -6.0066e-01,  9.8071e-01, -1.6869e+00],
        [ 1.8261e-01, -1.4915e+00,  1.7250e+00,  1.0730e+00,  1.5054e+00,
         -1.9959e-01,  3.5851e-01,  4.4212e-01, -1.6967e-01,  6.0200e-01],
        [ 1.6055e+00, -1.0208e-01,  5.5977e-01, -5.2612e-01,  7.0976e-01,
          4.6681e-01,  1.6503e+00, -1.4819e+00, -1.0760e+00,  2.3462e+00],
        [ 3.0298e-01,  3.6652e-01, -1.2854e+00,  1.9645e-01,  1.5433e+00,
         -4.8312e-01, -2.4922e+00, -2.8766e-01,  4.0799e-01, -4.7732e-01],
        [-3.7122e-01, -3.9824e-01, -1.5951e+00,  1.0349e+00, -3.5979e-01,
         -5.4680e-01, -9.6843e-01, -2.3286e-01,  1.2434e+00, -1.4911e+00],
        [-4.0563e-01,  2.2707e-01,  1.2148e+00,  3.3989e-01,  1.3016e-01,
         -1.0328e+00,  1.9271e+0

In [5]:
print(L_im)

[[0.37454012 0.95071431 0.73199394 0.59865848 0.15601864]
 [0.15599452 0.05808361 0.86617615 0.60111501 0.70807258]
 [0.02058449 0.96990985 0.83244264 0.21233911 0.18182497]
 [0.18340451 0.30424224 0.52475643 0.43194502 0.29122914]
 [0.61185289 0.13949386 0.29214465 0.36636184 0.45606998]
 [0.78517596 0.19967378 0.51423444 0.59241457 0.04645041]
 [0.60754485 0.17052412 0.06505159 0.94888554 0.96563203]
 [0.80839735 0.30461377 0.09767211 0.68423303 0.44015249]
 [0.12203823 0.49517691 0.03438852 0.9093204  0.25877998]
 [0.66252228 0.31171108 0.52006802 0.54671028 0.18485446]
 [0.96958463 0.77513282 0.93949894 0.89482735 0.59789998]
 [0.92187424 0.0884925  0.19598286 0.04522729 0.32533033]
 [0.38867729 0.27134903 0.82873751 0.35675333 0.28093451]
 [0.54269608 0.14092422 0.80219698 0.07455064 0.98688694]
 [0.77224477 0.19871568 0.00552212 0.81546143 0.70685734]
 [0.72900717 0.77127035 0.07404465 0.35846573 0.11586906]
 [0.86310343 0.62329813 0.33089802 0.06355835 0.31098232]
 [0.32518332 0

In [6]:

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    # Forward pass
    pi = model(X)
    loss = criterion(pi, X)
    
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
 

Epoch [10/100], Loss: 21.1211
Epoch [20/100], Loss: 17.3898
Epoch [30/100], Loss: 13.8663
Epoch [40/100], Loss: 11.2373
Epoch [50/100], Loss: 9.9291
Epoch [60/100], Loss: 9.0841
Epoch [70/100], Loss: 8.7452
Epoch [80/100], Loss: 8.5897
Epoch [90/100], Loss: 8.3875
Epoch [100/100], Loss: 8.3368


In [7]:
model.eval()
with torch.no_grad():
    fitted_pi = model(X)
    print("\nFitted pi values for each row of L_im:")
    for i in range(batch_size):
        print(f"Row {i+1}:")
        print(f"L_im: {L_im[i]}")
        print(f"Fitted pi: {fitted_pi[i].numpy()}")
        print(f"Sum of fitted pi: {fitted_pi[i].sum().item():.4f}")
        print()

# Print a summary of pi values
print("Summary of fitted pi values:")
print(f"Mean: {fitted_pi.mean(dim=0).numpy()}")
print(f"Min: {fitted_pi.min(dim=0)[0].numpy()}")
print(f"Max: {fitted_pi.max(dim=0)[0].numpy()}") 


Fitted pi values for each row of L_im:
Row 1:
L_im: [0.37454012 0.95071431 0.73199394 0.59865848 0.15601864]
Fitted pi: [3.1463206e-03 9.9523711e-01 1.4966888e-03 8.5036416e-05 3.4706241e-05]
Sum of fitted pi: 1.0000

Row 2:
L_im: [0.15599452 0.05808361 0.86617615 0.60111501 0.70807258]
Fitted pi: [1.1397196e-03 3.5853568e-04 9.9848652e-01 1.0370062e-05 4.8927418e-06]
Sum of fitted pi: 1.0000

Row 3:
L_im: [0.02058449 0.96990985 0.83244264 0.21233911 0.18182497]
Fitted pi: [1.6097986e-04 9.9163967e-01 4.4779363e-04 1.7617054e-05 7.7339378e-03]
Sum of fitted pi: 1.0000

Row 4:
L_im: [0.18340451 0.30424224 0.52475643 0.43194502 0.29122914]
Fitted pi: [2.3987794e-03 1.1582412e-04 9.9743098e-01 1.8259969e-07 5.4241842e-05]
Sum of fitted pi: 1.0000

Row 5:
L_im: [0.61185289 0.13949386 0.29214465 0.36636184 0.45606998]
Fitted pi: [9.9214184e-01 3.7237371e-03 1.3143159e-05 3.2847980e-05 4.0884828e-03]
Sum of fitted pi: 1.0000

Row 6:
L_im: [0.78517596 0.19967378 0.51423444 0.59241457 0.04645

In [8]:
fitted_pi

tensor([[3.1463e-03, 9.9524e-01, 1.4967e-03, 8.5036e-05, 3.4706e-05],
        [1.1397e-03, 3.5854e-04, 9.9849e-01, 1.0370e-05, 4.8927e-06],
        [1.6098e-04, 9.9164e-01, 4.4779e-04, 1.7617e-05, 7.7339e-03],
        [2.3988e-03, 1.1582e-04, 9.9743e-01, 1.8260e-07, 5.4242e-05],
        [9.9214e-01, 3.7237e-03, 1.3143e-05, 3.2848e-05, 4.0885e-03],
        [9.9812e-01, 1.5631e-06, 1.6827e-05, 1.6058e-06, 1.8578e-03],
        [6.2040e-03, 2.6723e-04, 3.9094e-05, 1.9136e-05, 9.9347e-01],
        [9.9379e-01, 1.2566e-03, 2.9916e-03, 8.9599e-05, 1.8754e-03],
        [5.1925e-03, 9.9480e-01, 1.0219e-06, 1.6207e-06, 8.5646e-07],
        [3.8744e-07, 3.0041e-07, 1.0000e+00, 3.2798e-09, 5.6627e-09],
        [9.9941e-01, 2.6653e-05, 4.8146e-04, 1.9190e-06, 7.5516e-05],
        [9.9617e-01, 1.2576e-04, 1.8163e-04, 3.3744e-05, 3.4869e-03],
        [9.7266e-01, 2.1215e-02, 5.9079e-03, 1.1788e-04, 9.6368e-05],
        [6.9845e-04, 1.2188e-04, 6.6762e-03, 4.3927e-06, 9.9250e-01],
        [9.7929e-01,