In [64]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [65]:
# Load the zoo dataset. Every column except the last is a feature; the last
# column is the class label, read as float32 along with everything else.
xy = np.loadtxt('data-04-zoo.csv', delimiter=',', dtype=np.float32)

# Labels are truncated to int64 below, so make sure they really are whole numbers.
assert np.array_equal(xy[:, -1], np.round(xy[:, -1])), "label column must be integral"

x_train = torch.tensor(xy[:, :-1], dtype=torch.float32)
# Scalar column index (-1, not [-1]) yields a 1-D array of shape (N,) directly.
# Unlike .squeeze(), it stays 1-D even for a single-row file, which
# F.cross_entropy needs as class-index targets.
y_train = torch.tensor(xy[:, -1], dtype=torch.long)

In [66]:
# Peek at the first four samples: each feature row next to its class label.
for row in range(4):
    print(x_train[row, :], y_train[row])

tensor([1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 4., 0., 0., 1.]) tensor(0)
tensor([1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 4., 1., 0., 1.]) tensor(0)
tensor([0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 1., 0., 1., 0., 0.]) tensor(3)
tensor([1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 0., 0., 4., 0., 0., 1.]) tensor(0)


In [67]:
# Softmax (multinomial logistic) regression model.
class SoftmaxClassifierModel(nn.Module):
    """A single linear layer that maps feature vectors to class scores (logits).

    No softmax is applied in ``forward``. ``F.cross_entropy`` applies
    log-softmax internally, so it must receive the raw logits.

    Args:
        in_features: number of input features per sample (default 16, the
            width of the zoo feature rows).
        num_classes: number of output classes (default 7).
    """

    def __init__(self, in_features=16, num_classes=7):
        super().__init__()
        self.linear = nn.Linear(in_features, num_classes)

    def forward(self, x):
        """Return logits of shape (N, num_classes) for x of shape (N, in_features)."""
        return self.linear(x)

In [68]:
# Seed PyTorch's global RNG before building the model. nn.Linear initialises
# its weights randomly, so without a seed the training curve and accuracy
# below change on every Restart & Run All.
torch.manual_seed(1)

model = SoftmaxClassifierModel()

# Plain gradient descent over all model parameters with learning rate 0.01.
optimizer = optim.SGD(model.parameters(), lr=0.01)

In [69]:
nb_epochs = 1000
# Full-batch training: every step uses all of x_train.
# range(nb_epochs + 1) covers epochs 0..nb_epochs inclusive, so the final epoch is logged too.
for epoch in range(nb_epochs + 1):
    # Drop gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Forward pass yields raw logits; F.cross_entropy applies log-softmax itself.
    hypothesis = model(x_train)
    cost = F.cross_entropy(hypothesis, y_train)

    # Backpropagate and take a single optimizer step.
    cost.backward()
    optimizer.step()

    # Every 100 epochs, report the loss measured before this step's update.
    if not epoch % 100:
        print(f'Epoch {epoch:4d}/{nb_epochs} Cost: {cost.item():.6f} ')

Epoch    0/1000 Cost: 1.810791 
Epoch  100/1000 Cost: 1.222991 
Epoch  200/1000 Cost: 0.971313 
Epoch  300/1000 Cost: 0.817205 
Epoch  400/1000 Cost: 0.714004 
Epoch  500/1000 Cost: 0.640063 
Epoch  600/1000 Cost: 0.584196 
Epoch  700/1000 Cost: 0.540192 
Epoch  800/1000 Cost: 0.504375 
Epoch  900/1000 Cost: 0.474456 
Epoch 1000/1000 Cost: 0.448938 


In [70]:
# Show the learned parameters in the order model.parameters() yields them
# (for the single nn.Linear layer: the weight matrix, then the bias vector).
params = list(model.parameters())
W, b = params[0], params[1]
print(W, b)

Parameter containing:
tensor([[ 0.7477, -0.1406, -1.0368,  0.7523, -0.2535, -0.3494, -0.0364,  0.5029,
          0.4708,  0.3080, -0.1459, -0.1576,  0.1145,  0.1353,  0.0825,  0.6486],
        [-0.0677,  0.8028,  0.2842, -0.3765,  0.4569, -0.2620, -0.0603, -0.6387,
          0.2466,  0.1568,  0.0593, -0.3840, -0.1508,  0.5774,  0.0949,  0.2553],
        [-0.0516, -0.0624,  0.0441, -0.1485, -0.1357, -0.3067, -0.0057,  0.1042,
          0.1894,  0.1823,  0.3353, -0.3243, -0.2362,  0.2487, -0.0964, -0.2570],
        [ 0.0018,  0.0187,  0.3793, -0.3161,  0.0714,  0.3134,  0.2509,  0.5267,
          0.0873, -0.5986, -0.1529,  0.6871, -0.6751,  0.0270, -0.1860,  0.1070],
        [-0.1698, -0.1618, -0.1305, -0.0414,  0.0747,  0.2947,  0.0022,  0.0506,
          0.0639, -0.0597,  0.2525, -0.3183,  0.0767, -0.4462, -0.1606, -0.1522],
        [ 0.0703, -0.0437, -0.0311,  0.0544,  0.1018, -0.2055, -0.4161, -0.4223,
         -0.4983, -0.1888,  0.0767, -0.2806,  0.4747, -0.4888, -0.1334, -0.1631],


In [73]:
# Evaluate on the training data itself (this notebook has no held-out split).
# no_grad skips building an autograd graph for the forward pass.
with torch.no_grad():
    output = model(x_train)
    # Predicted class = index of the largest logit in each row.
    prediction = output.argmax(dim=1)
    print(prediction)

    correct_prediction = prediction == y_train
    print(correct_prediction)

    # Fraction of samples whose predicted class matches the label.
    accuracy = correct_prediction.to(torch.float32).mean()
    print('Accuracy:', accuracy.item())

tensor([0, 0, 3, 0, 0, 0, 0, 3, 3, 0, 0, 1, 3, 3, 6, 6, 1, 0, 3, 0, 1, 1, 0, 1,
        5, 4, 4, 0, 0, 0, 5, 0, 0, 1, 3, 0, 0, 1, 3, 5, 5, 1, 5, 1, 0, 0, 6, 0,
        0, 0, 0, 5, 0, 6, 0, 0, 1, 1, 1, 1, 3, 3, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1,
        5, 3, 0, 0, 3, 3, 1, 1, 3, 1, 3, 1, 0, 6, 3, 1, 5, 4, 1, 0, 3, 0, 0, 1,
        0, 5, 0, 1, 1])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True, False,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True, False,  True,  True,  True, False, False,  True,  True,
        False, False,  True,  Tr