<a href="https://colab.research.google.com/github/rtarun1/pytorch-01/blob/main/pytorch02iris.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [46]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

In [47]:
# Create a Model Class that inherits nn.Module

class Model(nn.Module):
  #Input layer (4 features of the flower) Hidden layer -- H2 -- Output
  def __init__(self, in_features = 4, h1 = 8, h2 = 9, out_features = 3):
    super().__init__()
    self.fc1 = nn.Linear(in_features, h1)
    self.fc2 = nn.Linear(h1, h2)
    self.out = nn.Linear(h2, out_features)

  def forward(self, x):
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    x = self.out(x)

    return x

In [48]:
torch.manual_seed(69)

model = Model()

In [49]:
url = 'https://gist.githubusercontent.com/curran/a08a1080b88344b0c8a7/raw/0e7a9b0a5d22642a06d3d5b9bcbad9890c8ee534/iris.csv'
my_df = pd.read_csv(url)
my_df

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,setosa
1,4.9,3.0,1.4,0.2,setosa
2,4.7,3.2,1.3,0.2,setosa
3,4.6,3.1,1.5,0.2,setosa
4,5.0,3.6,1.4,0.2,setosa
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,virginica
146,6.3,2.5,5.0,1.9,virginica
147,6.5,3.0,5.2,2.0,virginica
148,6.2,3.4,5.4,2.3,virginica


In [50]:
my_df['species'] = my_df['species'].replace('setosa', 0.0)
my_df['species'] = my_df['species'].replace('versicolor', 1.0)
my_df['species'] = my_df['species'].replace('virginica', 2.0)
my_df

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width,species
0,5.1,3.5,1.4,0.2,0.0
1,4.9,3.0,1.4,0.2,0.0
2,4.7,3.2,1.3,0.2,0.0
3,4.6,3.1,1.5,0.2,0.0
4,5.0,3.6,1.4,0.2,0.0
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,2.0
146,6.3,2.5,5.0,1.9,2.0
147,6.5,3.0,5.2,2.0,2.0
148,6.2,3.4,5.4,2.3,2.0


In [51]:
X = my_df.drop('species', axis = 1)
y = my_df['species']
X = X.values
y = y.values

In [52]:
from sklearn.model_selection import train_test_split

In [53]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 69)

In [54]:
X_train = torch.FloatTensor(X_train)
X_test = torch.FloatTensor(X_test)

In [55]:
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

In [56]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.01)

In [57]:
epochs = 500
losses = []
for i in range(epochs):
  y_pred = model.forward(X_train)

  loss = criterion(y_pred, y_train)

  losses.append(loss.detach().numpy())

  if i % 10 == 0:
    print(f'Epoch: {i} and Losses: {loss}')

  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

Epoch: 0 and Losses: 1.0933918952941895
Epoch: 10 and Losses: 0.8141745328903198
Epoch: 20 and Losses: 0.5294780135154724
Epoch: 30 and Losses: 0.4019189774990082
Epoch: 40 and Losses: 0.2870267331600189
Epoch: 50 and Losses: 0.18031372129917145
Epoch: 60 and Losses: 0.11969298869371414
Epoch: 70 and Losses: 0.092463918030262
Epoch: 80 and Losses: 0.07874865084886551
Epoch: 90 and Losses: 0.07112809270620346
Epoch: 100 and Losses: 0.06658605486154556
Epoch: 110 and Losses: 0.06364356726408005
Epoch: 120 and Losses: 0.06154729798436165
Epoch: 130 and Losses: 0.059934351593256
Epoch: 140 and Losses: 0.05868910253047943
Epoch: 150 and Losses: 0.05789446085691452
Epoch: 160 and Losses: 0.05688353255391121
Epoch: 170 and Losses: 0.056084588170051575
Epoch: 180 and Losses: 0.05547191575169563
Epoch: 190 and Losses: 0.05491584539413452
Epoch: 200 and Losses: 0.054414138197898865
Epoch: 210 and Losses: 0.05396454036235809
Epoch: 220 and Losses: 0.05355878546833992
Epoch: 230 and Losses: 0.0531

In [58]:
with torch.no_grad():
  y_eval = model.forward(X_test)
  loss = criterion(y_eval, y_test)

In [59]:
loss


tensor(0.0101)

In [60]:
correct = 0
with torch.no_grad():
  for i, data in enumerate(X_test):
    y_val = model.forward(data)
    print(f'{i+1}.) {str(y_val)} \t {y_test[i]}')

    if y_val.argmax().item() == y_test[i]:
      correct +=1

  print (correct)


1.) tensor([ 14.2881,   3.8119, -35.1578]) 	 0
2.) tensor([-5.1651,  7.4313, -2.2603]) 	 1
3.) tensor([ 13.3799,   4.7377, -34.3573]) 	 0
4.) tensor([-12.7828,   3.9416,  15.6897]) 	 2
5.) tensor([ 13.1734,   4.2853, -33.4004]) 	 0
6.) tensor([ 15.9889,   3.7030, -38.8002]) 	 0
7.) tensor([-4.5628,  6.5966, -1.9545]) 	 1
8.) tensor([-14.1027,   5.4538,  15.6248]) 	 2
9.) tensor([-10.9480,   6.8164,   8.3986]) 	 2
10.) tensor([ 15.6365,   3.8081, -38.1198]) 	 0
11.) tensor([-9.6066,  5.7392,  7.9149]) 	 2
12.) tensor([-15.8319,   5.6045,  18.1286]) 	 2
13.) tensor([ 12.2905,   3.9282, -30.8659]) 	 0
14.) tensor([-5.7617,  6.6518, -0.1384]) 	 1
15.) tensor([-12.8851,   4.9602,  14.4000]) 	 2
16.) tensor([-5.4306,  7.1298, -1.3166]) 	 1
17.) tensor([-1.6134,  7.1897, -7.7771]) 	 1
18.) tensor([-13.9625,   4.8033,  16.2999]) 	 2
19.) tensor([-3.3083,  7.1545, -4.9575]) 	 1
20.) tensor([-11.7816,   4.4841,  13.2609]) 	 2
21.) tensor([-14.6706,   4.2632,  18.2659]) 	 2
22.) tensor([-2.3565, 