# model()

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# convert NumPy array into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# split
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7, shuffle=True)

# PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

model = Multiclass()
    
# loss metric and optimizer
loss_fn = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# prepare model and training parameters
n_epochs = 100
batch_size = 5
batch_start = torch.arange(0, len(X), batch_size)

# training loop
for epoch in range(n_epochs):
    for start in batch_start:
        # take a batch
        X_batch = X_train[start:start+batch_size]
        y_batch = y_train[start:start+batch_size]
        # forward pass
        y_pred = model(X_batch)
        loss = loss_fn(y_pred, y_batch)
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        # update weights
        optimizer.step()

# Save model
torch.save(model.state_dict(), "iris-model.pth")

# how to load the ready-made pre-training parameters model to obtain the final result

In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split


# Load data into NumPy arrays
data = load_iris()
X, y = data["data"], data["target"]

# convert NumPy array into PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# PyTorch model
class Multiclass(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(4, 8)
        self.act = nn.ReLU()
        self.output = nn.Linear(8, 3)
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = self.act(self.hidden(x))
        x = self.logsoftmax(self.output(x))
        return x

# Create new model and load states
model = Multiclass()
model.load_state_dict(torch.load('iris-model.pth'))
    
print(torch.load('iris-model.pth'))
print("--------------------------------------------")

# Run model for inference
''' 
Because the class Multiclass() has already been instantiated, 
the following parentheses `()` operations will directly call the function of forward()
'''
y_pred = model(X_test)
print("This is the y_pred: ",y_pred.shape,y_pred)
print("This is the y_test:",y_test.shape,y_test)
print("-----------------------------------------")
print(torch.argmax(y_pred, 1) == y_test)
print((torch.argmax(y_pred, 1) == y_test).float())
acc = (torch.argmax(y_pred, 1) == y_test).float().mean()
print("Accuracy: %.8f" % acc)

OrderedDict([('hidden.weight', tensor([[-0.2543, -0.1561, -0.4526,  0.4101],
        [-0.2878,  0.1367, -0.3660,  0.0995],
        [ 0.0171, -0.4326,  0.9949,  1.4759],
        [-0.3806, -0.3014,  0.3371, -0.1611],
        [ 0.5568,  1.2650, -0.1330, -1.4432],
        [ 0.0568, -0.0711, -0.4397,  0.2276],
        [-0.2176, -0.2337,  0.0726,  0.1096],
        [ 0.1246, -0.5134,  1.0351,  0.6406]])), ('hidden.bias', tensor([-0.3269,  0.4239, -1.0601,  0.1092,  0.8992,  0.3887, -0.2973, -0.6241])), ('output.weight', tensor([[ 0.0760, -0.0117, -0.6924, -0.2346,  0.6248, -0.0626, -0.2156, -1.1969],
        [-0.0073, -0.0147,  0.0547,  0.2712,  0.0677, -0.1849,  0.3444,  0.1543],
        [-0.0563,  0.2219,  0.7317,  0.1037, -0.7972,  0.2775,  0.0540,  0.4696]])), ('output.bias', tensor([ 0.5954,  0.3059, -0.3317]))])
--------------------------------------------
This is the y_pred:  torch.Size([45, 3]) tensor([[-1.2762e+01, -2.1240e+00, -1.2733e-01],
        [-1.8920e-02, -3.9783e+00, -1.0611