In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [6]:
class CustomNet(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(CustomNet, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=1)
    
    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        pi = self.softmax(x)  # Ensures sum over m of pi_{im}(x) equals 1 and each pi_{im}(x) is positive
        return pi

class CustomLikelihoodLoss(nn.Module):
    def __init__(self, L_im):
        super(CustomLikelihoodLoss, self).__init__()
        self.L_im = torch.tensor(L_im, dtype=torch.float32)
    
    def forward(self, pi, x):
        # pi has shape [batch_size, M]
        # L_im has shape [batch_size, M]
        
        # Compute the sum over m
        inner_sum = torch.sum(pi * self.L_im, dim=1)
        
        # Compute log and sum over i (batch)
        log_likelihood = torch.sum(torch.log(inner_sum))
        
        # Return negative log-likelihood as we want to minimize
        return -log_likelihood


In [13]:
input_size = 10
hidden_size = 20
output_size = 5  # This is M in your equation
batch_size = 32
np.random.seed(42)  
# Generate dummy data
X = torch.randn(batch_size, input_size)
L_im = np.random.rand(batch_size, output_size)  # This would be your actual L_im values


# Create model, loss function, and optimizer
model = CustomNet(input_size, hidden_size, output_size)
criterion = CustomLikelihoodLoss(L_im)
optimizer = optim.Adam(model.parameters(), lr=0.01)


In [14]:
print(X)

tensor([[-1.3484, -0.2640,  0.9099, -1.5606,  1.4378,  0.0040, -1.4830,  0.8966,
         -0.6763, -0.4695],
        [-0.5849, -0.4338, -1.1058, -0.9041,  1.6045,  1.0774, -0.5318,  0.6545,
         -0.0977,  1.2266],
        [ 0.7214,  1.1877, -0.0591, -0.2207,  0.3404,  0.0862, -0.6847,  2.3038,
         -0.4819,  0.0228],
        [-0.1568, -0.5630, -0.7536,  0.2923,  1.2674,  0.0224, -0.3978, -1.5272,
          0.1080,  0.4921],
        [ 1.8564, -0.1553, -0.8271,  0.5344, -0.2161, -0.2129,  1.1862,  0.9169,
         -0.0102, -0.8224],
        [ 0.0225, -0.7149,  0.9777, -0.5788, -0.0807,  0.7151, -0.5676, -0.9079,
         -0.3319,  1.0973],
        [-0.5677, -1.0181, -1.1159,  0.3624, -0.7750,  1.2938,  0.6408, -0.8450,
          0.0358, -1.0565],
        [-1.0109,  0.7868, -0.1744,  0.4529, -0.1628, -0.0193,  1.8662,  0.6820,
          0.2044,  1.6355],
        [-0.3323,  1.3491, -0.0715,  1.2324, -0.5957, -0.5797, -0.5425,  0.4078,
         -1.0083, -0.3317],
        [-2.2853, -

In [15]:
print(L_im)

[[0.26668569 0.78754869 0.67316878 0.64544026 0.69891338]
 [0.64933105 0.37334476 0.05091967 0.96732522 0.35050185]
 [0.2825673  0.64677301 0.20374292 0.61359931 0.90011016]
 [0.24479624 0.66402023 0.50719465 0.72633718 0.50487767]
 [0.26082044 0.70374839 0.99138892 0.80076073 0.7538149 ]
 [0.77248328 0.05536853 0.40989202 0.33879961 0.94302775]
 [0.60436542 0.48822427 0.90328979 0.69203341 0.14293148]
 [0.45516054 0.13671865 0.08645839 0.94669641 0.11910918]
 [0.78748762 0.97210374 0.1005992  0.6977242  0.38472298]
 [0.22061274 0.71521998 0.36114728 0.87919839 0.91122006]
 [0.97066013 0.91793721 0.15983995 0.20780622 0.951308  ]
 [0.50544281 0.46120077 0.09785864 0.00143164 0.68084822]
 [0.11172724 0.97841989 0.8927301  0.65654211 0.97026546]
 [0.36802337 0.88772818 0.36086246 0.89516289 0.96831611]
 [0.76528408 0.96614103 0.30983873 0.47743863 0.23390017]
 [0.07727259 0.915316   0.64367995 0.39141778 0.21471649]
 [0.56802206 0.73179308 0.09549573 0.35435044 0.3368751 ]
 [0.42999903 0

In [16]:

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    # Forward pass
    pi = model(X)
    loss = criterion(pi, X)
    
    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
 

Epoch [10/100], Loss: 18.9447
Epoch [20/100], Loss: 14.8255
Epoch [30/100], Loss: 12.2333
Epoch [40/100], Loss: 10.3567
Epoch [50/100], Loss: 9.2267
Epoch [60/100], Loss: 8.5829
Epoch [70/100], Loss: 8.1924
Epoch [80/100], Loss: 7.9631
Epoch [90/100], Loss: 7.7877
Epoch [100/100], Loss: 7.6098
Model prediction (pi values):
[[0.45007452 0.07221188 0.00194814 0.37909746 0.09666806]]
Sum of pi values: 1.0


In [18]:
model.eval()
with torch.no_grad():
    fitted_pi = model(X)
    print("\nFitted pi values for each row of L_im:")
    for i in range(batch_size):
        print(f"Row {i+1}:")
        print(f"L_im: {L_im[i]}")
        print(f"Fitted pi: {fitted_pi[i].numpy()}")
        print(f"Sum of fitted pi: {fitted_pi[i].sum().item():.4f}")
        print()

# Print a summary of pi values
print("Summary of fitted pi values:")
print(f"Mean: {fitted_pi.mean(dim=0).numpy()}")
print(f"Min: {fitted_pi.min(dim=0)[0].numpy()}")
print(f"Max: {fitted_pi.max(dim=0)[0].numpy()}") 


Fitted pi values for each row of L_im:
Row 1:
L_im: [0.26668569 0.78754869 0.67316878 0.64544026 0.69891338]
Fitted pi: [7.5818819e-04 5.8049809e-06 1.1075619e-08 1.2014237e-05 9.9922395e-01]
Sum of fitted pi: 1.0000

Row 2:
L_im: [0.64933105 0.37334476 0.05091967 0.96732522 0.35050185]
Fitted pi: [2.0695526e-02 5.4418647e-06 3.6365960e-05 9.7259688e-01 6.6657299e-03]
Sum of fitted pi: 1.0000

Row 3:
L_im: [0.2825673  0.64677301 0.20374292 0.61359931 0.90011016]
Fitted pi: [5.8365888e-03 5.7653352e-03 2.0842360e-04 5.8850567e-03 9.8230457e-01]
Sum of fitted pi: 1.0000

Row 4:
L_im: [0.24479624 0.66402023 0.50719465 0.72633718 0.50487767]
Fitted pi: [2.9481018e-03 9.2029715e-01 6.5996574e-04 1.7363900e-02 5.8730867e-02]
Sum of fitted pi: 1.0000

Row 5:
L_im: [0.26082044 0.70374839 0.99138892 0.80076073 0.7538149 ]
Fitted pi: [0.06132855 0.20189281 0.2841508  0.16127539 0.29135242]
Sum of fitted pi: 1.0000

Row 6:
L_im: [0.77248328 0.05536853 0.40989202 0.33879961 0.94302775]
Fitted pi:

In [19]:
fitted_pi

tensor([[7.5819e-04, 5.8050e-06, 1.1076e-08, 1.2014e-05, 9.9922e-01],
        [2.0696e-02, 5.4419e-06, 3.6366e-05, 9.7260e-01, 6.6657e-03],
        [5.8366e-03, 5.7653e-03, 2.0842e-04, 5.8851e-03, 9.8230e-01],
        [2.9481e-03, 9.2030e-01, 6.5997e-04, 1.7364e-02, 5.8731e-02],
        [6.1329e-02, 2.0189e-01, 2.8415e-01, 1.6128e-01, 2.9135e-01],
        [2.9049e-01, 2.2976e-03, 2.0934e-04, 2.0980e-04, 7.0680e-01],
        [3.5472e-02, 4.9728e-03, 8.1414e-01, 1.4139e-01, 4.0216e-03],
        [8.0006e-03, 4.2458e-04, 2.2678e-03, 9.8864e-01, 6.6655e-04],
        [8.0134e-06, 9.9774e-01, 2.0779e-05, 6.0828e-05, 2.1687e-03],
        [3.5087e-04, 5.8416e-02, 4.3485e-05, 6.2941e-02, 8.7825e-01],
        [1.5127e-07, 9.9865e-01, 3.1229e-07, 4.4021e-07, 1.3466e-03],
        [9.6058e-03, 6.1944e-04, 1.7580e-03, 2.3468e-04, 9.8778e-01],
        [1.2933e-04, 1.2404e-02, 2.5156e-06, 1.7062e-06, 9.8746e-01],
        [4.8673e-04, 9.2949e-01, 2.8114e-03, 2.5316e-03, 6.4682e-02],
        [1.0602e-02,