In [1]:
import torch
import torch.nn as nn
import torch.optim as optim


This code defines a class called `PoissonRegression` which is a subclass of the `nn.Module` class from the PyTorch library. The purpose of this class is to implement a simple Poisson regression model.

In the `__init__` method (the constructor), the code takes an `input_dim` parameter, which represents the number of features in the input data. It then calls the `__init__` method of the parent class `nn.Module` using `super(PoissonRegression, self).__init__()` to initialize the parent class. After that, it creates a linear layer using `self.linear = nn.Linear(input_dim, 1)`. This linear layer maps the input data to a single output value.

The `forward` method is the main computation of the model. It takes an input `x` and applies the linear transformation to it using `self.linear(x)`. The result is then passed through the exponential function `torch.exp()` to ensure that the output is positive. Finally, the result is returned.

To summarize, this code defines a Poisson regression model using a single linear layer. The `forward` method performs the forward pass of the model, transforming the input data and returning the predicted output.

In [2]:
class PoissonRegression(nn.Module):
    def __init__(self, input_dim):
        super(PoissonRegression, self).__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        return torch.exp(self.linear(x))


generate data

In [3]:
import numpy as np
import random

# Generate the artificial data

n_total = int(20000)
K = 100
X = np.column_stack((np.ones(n_total), np.random.rand(n_total * (K - 1)).reshape(n_total, K - 1)))
b0 = 1.0 * np.ones(K) / K
y = np.random.poisson(np.exp(X.dot(b0)))


In [4]:
# Split the data into train and test sets
n = len(y)
test_indices = random.sample(range(n), round(0.2 * n))

y_test = y[test_indices]
X_test = X[test_indices, :]

y_train = np.delete(y, test_indices)
X_train = np.delete(X, test_indices, axis=0)

# convert X_train into a pytorch tensor
X_train = torch.from_numpy(X_train).float()
y_train = torch.from_numpy(y_train).float()

In [7]:
model = PoissonRegression(K)


In [6]:
def poisson_nll_loss(y_pred, y_true):
    return torch.mean(y_pred - y_true * torch.log(y_pred)) # this is the correct formula. 
# y_pred is the linear index, y_true is the true value.



In [9]:
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(100):  # number of epochs
    optimizer.zero_grad()
    y_pred = model(X_train) # the liner index
    loss = poisson_nll_loss(y_pred, y_train)
    loss.backward()
    optimizer.step()

    if epoch % 1 == 0:  # print every 100 epochs
        print(f'Epoch {epoch+1}, Loss: {loss.item()}')


Epoch 1, Loss: 0.9684693813323975
Epoch 2, Loss: 0.860050618648529
Epoch 3, Loss: 0.8407831788063049
Epoch 4, Loss: 0.8360377550125122
Epoch 5, Loss: 0.8347012400627136
Epoch 6, Loss: 0.8342761397361755
Epoch 7, Loss: 0.8341097235679626
Epoch 8, Loss: 0.834019124507904
Epoch 9, Loss: 0.8339512944221497
Epoch 10, Loss: 0.8338902592658997
Epoch 11, Loss: 0.8338314890861511
Epoch 12, Loss: 0.8337734341621399
Epoch 13, Loss: 0.8337157964706421
Epoch 14, Loss: 0.8336582779884338
Epoch 15, Loss: 0.8336010575294495
Epoch 16, Loss: 0.8335438966751099
Epoch 17, Loss: 0.8334869742393494
Epoch 18, Loss: 0.8334301114082336
Epoch 19, Loss: 0.833373486995697
Epoch 20, Loss: 0.83331698179245
Epoch 21, Loss: 0.833260715007782
Epoch 22, Loss: 0.8332045078277588
Epoch 23, Loss: 0.8331485390663147
Epoch 24, Loss: 0.8330925703048706
Epoch 25, Loss: 0.8330370187759399
Epoch 26, Loss: 0.832981526851654
Epoch 27, Loss: 0.8329259753227234
Epoch 28, Loss: 0.8328708410263062
Epoch 29, Loss: 0.8328157067298889
E

In [10]:
# print the coefficients of the model

# print(model.state_dict()['linear.weight'])
print(model.state_dict())

OrderedDict([('linear.weight', tensor([[ 0.0834,  0.0562,  0.0690,  0.0345, -0.0198,  0.0300,  0.0692,  0.0230,
         -0.0527, -0.0084, -0.0546, -0.0024, -0.0549, -0.0561, -0.0388,  0.0627,
         -0.0624,  0.0103, -0.0135,  0.0722, -0.0324, -0.0478,  0.0456, -0.0374,
         -0.0418,  0.0597, -0.0756, -0.0262,  0.0471, -0.0103, -0.0398,  0.0803,
          0.0713, -0.0195, -0.0448, -0.0580,  0.0692,  0.0129, -0.0729,  0.0359,
          0.0660, -0.0358,  0.0552,  0.0549, -0.0550,  0.0650,  0.0769, -0.0382,
          0.0744,  0.0161, -0.0114,  0.0550,  0.0430, -0.0277, -0.0782, -0.0500,
         -0.0491,  0.0586,  0.0573, -0.0520,  0.0633, -0.0037, -0.0186,  0.0366,
          0.0380,  0.0266,  0.0156,  0.0066,  0.0264,  0.0687,  0.0509,  0.0093,
          0.0486, -0.0131, -0.0198, -0.0573, -0.0122,  0.0455, -0.0352, -0.0093,
          0.0741, -0.0682, -0.0738, -0.0006, -0.0688, -0.0211,  0.0722,  0.0141,
          0.0251,  0.0700,  0.0525,  0.0202, -0.0202,  0.0702, -0.0528, -0.018