In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Load Iris and separate the feature matrix (X) from the class labels (y).
iris = load_iris()
X, y = iris.data, iris.target

# Reproducible 80/20 train/test split (fixed random_state).
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

# Standardize features. The scaler learns its statistics from the training
# split only and reuses them on the test split, so test data does not leak
# into preprocessing. Tuple assignment evaluates left to right, so the fit
# happens before the test split is transformed.
scaler = StandardScaler()
X_train, X_test = scaler.fit_transform(X_train), scaler.transform(X_test)

class SimpleANN(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    The forward pass returns the raw output of the last linear layer; no
    softmax is applied here.

    Args:
        input_size: Number of input features per sample.
        hidden_size: Width of the hidden layer.
        output_size: Number of output scores per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        # The names fc1 / relu / fc2 determine the state_dict keys, so keep them.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of feature rows to one score per output unit."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# Model dimensions: one input unit per feature column.
input_size = X_train.shape[1]
hidden_size = 10
# Derive the class count from the full label array, not the training split.
# CrossEntropyLoss requires every target to lie in [0, output_size). If a
# class were absent from y_train, len(set(y_train)) would shrink the output
# layer, and training would either crash on an out-of-range target or never
# be able to predict the missing class.
output_size = int(y.max()) + 1

# Seed PyTorch so the weight initialization (and therefore the printed losses
# and accuracy) is reproducible, matching the seeded train/test split.
torch.manual_seed(42)

model = SimpleANN(input_size, hidden_size, output_size)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train, dtype=torch.long)

# Full-batch training: each epoch is one optimizer step on the whole
# training set.
# NOTE(review): in the originally recorded output the loss only fell from
# ~1.21 to ~1.01 over 100 epochs, and test accuracy was 0.80. That suggests
# these settings underfit; consider more epochs or a larger learning rate.
num_epochs = 100
model.train()
for epoch in range(num_epochs):
    outputs = model(X_train_tensor)
    loss = criterion(outputs, y_train_tensor)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Evaluate on the held-out split. Switch to inference mode first. SimpleANN
# currently has no dropout or batch-norm layers, so eval() changes nothing
# today, but it keeps this evaluation correct if such layers are added.
model.eval()
with torch.no_grad():
    X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.long)
    outputs = model(X_test_tensor)
    # Predicted class = index of the largest score in each row.
    predicted = outputs.argmax(dim=1)
    # Fraction of test samples whose prediction matches the label.
    accuracy = (predicted == y_test_tensor).sum().item() / len(y_test)
    print(f'Test Accuracy: {accuracy:.4f}')

Epoch [10/100], Loss: 1.2130
Epoch [20/100], Loss: 1.1779
Epoch [30/100], Loss: 1.1479
Epoch [40/100], Loss: 1.1219
Epoch [50/100], Loss: 1.0991
Epoch [60/100], Loss: 1.0786
Epoch [70/100], Loss: 1.0593
Epoch [80/100], Loss: 1.0412
Epoch [90/100], Loss: 1.0236
Epoch [100/100], Loss: 1.0062
Test Accuracy: 0.8000
