In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score

In [2]:
# Load Iris data
iris = load_iris()
X = iris.data
y = iris.target

# Normalize features
scaler = StandardScaler()
X = scaler.fit_transform(X)

# Convert to PyTorch tensors
X = torch.tensor(X, dtype=torch.float32)
y = torch.tensor(y, dtype=torch.long)

# Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [3]:
class IrisNet(nn.Module):
    def __init__(self):
        super(IrisNet, self).__init__()
        self.fc1 = nn.Linear(4, 16)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(16, 3)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

#### In the criterion below add Cross Entropy Loss
#### In the optimizer add the parameters of the model
#### Add learning rate (alpha) as 0.01
#### Add momentum as 0.9
#### Enable the nesterov function

In [None]:
model = IrisNet()
criterion = 
optimizer = optim.SGD()

In [5]:
epochs = 100
for epoch in range(epochs):
    model.train()
    
    outputs = model(X_train)
    loss = criterion(outputs, y_train)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}")

Epoch [10/100], Loss: 0.9120
Epoch [20/100], Loss: 0.6692
Epoch [30/100], Loss: 0.5267
Epoch [40/100], Loss: 0.4499
Epoch [50/100], Loss: 0.4022
Epoch [60/100], Loss: 0.3673
Epoch [70/100], Loss: 0.3389
Epoch [80/100], Loss: 0.3140
Epoch [90/100], Loss: 0.2913
Epoch [100/100], Loss: 0.2703


In [6]:
model.eval()
with torch.no_grad():
    predictions = model(X_test)
    _, predicted_classes = torch.max(predictions, 1)
    acc = accuracy_score(y_test, predicted_classes)
    print(f"Test Accuracy: {acc * 100:.2f}%")

Test Accuracy: 93.33%
