In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


In [2]:
class PoissonRegression(nn.Module):
    def __init__(self, input_dim):
        super(PoissonRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.exp(self.linear(x))


generate data

In [3]:
import numpy as np
import random

# Generate the artificial data
# random.seed(898)
n_total = int(10000)
K = 500
X = np.column_stack((np.ones(n_total), np.random.rand(n_total * (K - 1)).reshape(n_total, K - 1)))
b0 = 0.0 * np.ones(K) / K
y = np.random.poisson(np.exp(X.dot(b0)))


In [4]:
# Split the data into train and test sets
n = len(y)
test_indices = random.sample(range(n), round(0.2 * n))

y_test = y[test_indices]
X_test = X[test_indices, :]

y_train = np.delete(y, test_indices)
X_train = np.delete(X, test_indices, axis=0)

# convert X_train into a pytorch tensor
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float()

In [5]:
model = PoissonRegression(K)
optimizer = optim.SGD(model.parameters(), lr=1)


In [6]:
def poisson_nll_loss(y_pred, y_true):
    return torch.mean(y_pred - y_true * torch.log(y_pred)) # this is the correct formula. 
# y_pred is the linear index, y_true is the true value.




In [7]:
for epoch in range(10):  # number of epochs
    optimizer.zero_grad()
    y_pred = model(X_train) # the liner index
    loss = poisson_nll_loss(y_pred, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 1 == 0:  # print every 100 epochs
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')


Epoch 1, Loss: 1.018455147743225
Epoch 2, Loss: 12.277788162231445
Epoch 3, Loss: nan
Epoch 4, Loss: nan
Epoch 5, Loss: nan
Epoch 6, Loss: nan
Epoch 7, Loss: nan
Epoch 8, Loss: nan
Epoch 9, Loss: nan
Epoch 10, Loss: nan


In [8]:
# print the coefficients of the model

# print(model.state_dict()['linear.weight'])
print(model.state_dict())

OrderedDict([('linear.weight', tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, na