In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

In [2]:
from sklearn import datasets
from sklearn.model_selection import train_test_split

In [3]:
iris=datasets.load_iris()
xtrain,xtest,ytrain,ytest=train_test_split(iris.data,iris.target,test_size=0.4)

In [4]:
xtrain=torch.from_numpy(xtrain).type('torch.FloatTensor')
ytrain=torch.from_numpy(ytrain).type('torch.LongTensor')
xtest=torch.from_numpy(xtest).type('torch.FloatTensor')
ytest=torch.from_numpy(ytest).type('torch.LongTensor')

In [5]:
class MyIris(nn.Module):
  def __init__(self):
    super(MyIris,self).__init__()
    self.l1=nn.Linear(4,6)
    self.l2=nn.Linear(6,3)
  def forward(self,x):
    h1=torch.sigmoid(self.l1(x))
    h2=self.l2(h1)
    return h2

In [6]:
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
xtrain=xtrain.to(device)
ytrain=ytrain.to(device)
xtest=xtest.to(device)
ytest=ytest.to(device)

In [7]:
model=MyIris().to(device)
optimizer=optim.SGD(model.parameters(),lr=0.1)
criterion=nn.CrossEntropyLoss()

In [8]:
n=90
bs=25

model.train()
for i in range(1000):
  idx=np.random.permutation(n)
  for j in range(0,n,bs):
    xtm=xtrain[idx[j:(j+bs) if (j+bs) < n else n]]
    ytm=ytrain[idx[j:(j+bs) if (j+bs) < n else n]]
    output=model(xtm)
    loss=criterion(output,ytm)
    print(i,j,loss.item())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

torch.save(model.state_dict(),'myiris.model')
model.load_state_dict(torch.load('myiris.model'))

model.eval()
with torch.no_grad():
  output=model(xtest)
  ans=torch.argmax(output,1)
  print(((ytest==ans).sum().float()/len(ans)).item())


0 0 1.4444714784622192
0 25 1.136846899986267
0 50 1.3294583559036255
0 75 1.1310521364212036
1 0 1.1436524391174316
1 25 1.1217418909072876
1 50 1.0592625141143799
1 75 1.1721745729446411
2 0 1.059639811515808
2 25 1.0656123161315918
2 50 1.1023539304733276
2 75 1.0667567253112793
3 0 1.0701749324798584
3 25 1.0490111112594604
3 50 1.0467737913131714
3 75 1.0398008823394775
4 0 1.0421384572982788
4 25 1.062664270401001
4 50 1.0356472730636597
4 75 1.0380651950836182
5 0 1.048407793045044
5 25 1.0280531644821167
5 50 1.0293248891830444
5 75 1.0300191640853882
6 0 1.019212007522583
6 25 1.0229343175888062
6 50 1.0259594917297363
6 75 1.0233250856399536
7 0 1.005251407623291
7 25 1.0249844789505005
7 50 1.015414834022522
7 75 1.0145474672317505
8 0 1.0194658041000366
8 25 0.9998025298118591
8 50 1.0336174964904785
8 75 1.000177264213562
9 0 0.9982579946517944
9 25 0.9953194260597229
9 50 0.9997897148132324
9 75 1.001670241355896
10 0 0.9676984548568726
10 25 1.0225011110305786
10 50 0.97