In [1]:
import numpy as np
from sklearn.datasets import load_iris

In [2]:
import torch
from torch import nn
from torch.autograd import Variable

In [3]:
# Load the iris dataset bundled with scikit-learn; its `.data` and `.target`
# attributes are read in the next cell.
data = load_iris()

In [4]:
# Feature matrix (X) and label vector (y) taken from the loaded dataset.
X, y = data.data[:], data.target[:]

In [5]:
# shape of the feature matrix: (number of samples, number of features)

X.shape

(150, 4)

In [6]:
# shape of the target vector: one label per row of X
y.shape

(150,)

In [7]:
# distinct class labels present in the target vector
# (this returns the labels themselves, not how many there are)

np.unique(y)

array([0, 1, 2])

In [8]:
# Multinomial logistic regression expressed as a single-layer PyTorch module

class LogisticRegression(nn.Module):
    """Linear classifier producing one raw score per class for each sample.

    No softmax is applied here: ``forward`` returns the output of the linear
    layer unchanged.
    """

    def __init__(self, input_size, num_classes):
        """Create the single linear layer.

        Args:
            input_size: number of input features per sample.
            num_classes: number of scores produced per sample.
        """
        super().__init__()
        self.linear = nn.Linear(input_size, num_classes)

    def forward(self, x):
        """Return the linear layer's output for ``x``."""
        return self.linear(x)

In [9]:
# Size the model from the data instead of hard-coding 4 features / 3 classes.
# Assumes the class labels are the contiguous integers 0..C-1 (np.unique(y)
# above shows [0, 1, 2]), so the number of distinct labels is the number of
# output scores needed.
model = LogisticRegression(X.shape[1], len(np.unique(y)))

In [10]:
# Adam optimizer over all of the model's parameters, learning rate 0.01
# https://pytorch.org/docs/stable/optim.html

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [11]:
# Cross-entropy loss; in the training loop it is applied to the model's raw
# linear outputs and the integer class targets.
# https://pytorch.org/docs/stable/nn.html

criterion = nn.CrossEntropyLoss()

In [12]:
n_epochs = 1000

# Convert features and targets to tensors once, outside the loop: X and y do
# not change between epochs. Plain tensors are used because the `Variable`
# wrapper is deprecated (tensors track gradients themselves).
inputs = torch.from_numpy(X).float()
# CrossEntropyLoss needs int64 class indices; NumPy's default integer type is
# int32 on some platforms, so cast explicitly.
targets = torch.from_numpy(y).long()

for epoch in range(n_epochs):

    # clear gradients first, so leftovers from a previous (possibly
    # interrupted) run of this cell are never accumulated into this step
    optimizer.zero_grad()

    # forward pass; call the module itself rather than .forward() so that any
    # registered hooks are not bypassed
    outputs = model(inputs)

    # predicted class label = index of the highest score in each row
    # (detach: this bookkeeping must not be tracked by autograd)
    predicted_labels = outputs.detach().argmax(dim=1)

    # calculate loss (Cross-Entropy)
    loss = criterion(outputs, targets)

    # calculate number of correctly predicted data points
    corrects = torch.sum(predicted_labels == targets).item()

    # compute gradients
    loss.backward()

    # perform one step in the opposite direction to the gradient (update weights)
    optimizer.step()

    if epoch % 100 == 0:
        accuracy = corrects / len(X)
        print('epoch = {0}, loss = {1:.6f}, accuracy = {2:.3f}'.format(epoch, loss.item(), accuracy))

epoch = 0, loss = 2.473506, accuracy = 0.333
epoch = 100, loss = 0.512411, accuracy = 0.820
epoch = 200, loss = 0.397781, accuracy = 0.953
epoch = 300, loss = 0.328598, accuracy = 0.967
epoch = 400, loss = 0.277574, accuracy = 0.973
epoch = 500, loss = 0.238605, accuracy = 0.973
epoch = 600, loss = 0.208492, accuracy = 0.973
epoch = 700, loss = 0.184930, accuracy = 0.973
epoch = 800, loss = 0.166223, accuracy = 0.973
epoch = 900, loss = 0.151139, accuracy = 0.980
