In [2]:
import torch
from torch import nn

In [3]:
# Configuration: compute device plus a fixed RNG seed. torch.manual_seed also
# seeds CUDA, so the weight initialization in the model cell (and therefore the
# whole training run) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# XOR truth table: each row of `x` is an input pair (shape (4, 2)) and the
# matching row of `y` is its XOR label (shape (4, 1)), both float32.
x = torch.tensor([
  [0., 0.],
  [0., 1.],
  [1., 0.],
  [1., 1.],
], dtype=torch.float32, device=device)
y = torch.tensor([[0.], [1.], [1.], [0.]], dtype=torch.float32, device=device)

In [5]:
# XOR classifier: 2 inputs -> three sigmoid hidden layers of width 10 -> one
# sigmoid output unit. Each (n_in, n_out) pair contributes a Linear layer
# followed by a Sigmoid, so the final output lies in (0, 1) as BCELoss expects.
model = nn.Sequential(*[
  layer
  for n_in, n_out in ((2, 10), (10, 10), (10, 10), (10, 1))
  for layer in (nn.Linear(n_in, n_out), nn.Sigmoid())
]).to(device)

In [8]:
# Binary cross-entropy on the sigmoid output of `model`, optimized with plain
# SGD. BCELoss holds no parameters or buffers, so it needs no `.to(device)`.
LEARNING_RATE = 1  # SGD step size
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

In [10]:
# Full-batch training on the four XOR samples, logging and checkpointing every
# LOG_EVERY epochs.
# NOTE: this cell starts from the *current* weights of `model`. Re-running it
# without re-running the model cell resumes training rather than starting
# fresh. The recorded output (cost ~1e-4 already at epoch 0) looks like such a
# re-run, which also means the saved "epoch 0" checkpoint is not an early model.
N_EPOCHS = 10_001
LOG_EVERY = 1_000

for epoch in range(N_EPOCHS):
  optimizer.zero_grad()
  hypothesis = model(x)

  cost = criterion(hypothesis, y)
  cost.backward()
  optimizer.step()

  if epoch % LOG_EVERY == 0:
    # `cost` was computed before this epoch's optimizer.step(), whereas the
    # checkpoint below captures the weights after that step.
    print(f'epoch: {epoch}, cost: {cost.item():>0.4}')
    # Pickles the whole nn.Module (not only its state_dict); reading it back
    # on PyTorch >= 2.6 needs torch.load(..., weights_only=False).
    torch.save(model, f'mymodel_{epoch:04d}.pt')

epoch: 0, cost: 0.0001444
epoch: 1000, cost: 0.000117
epoch: 2000, cost: 9.811e-05
epoch: 3000, cost: 8.436e-05
epoch: 4000, cost: 7.393e-05
epoch: 5000, cost: 6.575e-05
epoch: 6000, cost: 5.916e-05
epoch: 7000, cost: 5.374e-05
epoch: 8000, cost: 4.92e-05
epoch: 9000, cost: 4.538e-05
epoch: 10000, cost: 4.21e-05


In [12]:
# Evaluate the current model on the four XOR inputs: threshold the sigmoid
# outputs at 0.5 and report the fraction of predictions that match `y`.
with torch.no_grad():
  hypothesis = model(x)
  predicted = hypothesis.gt(0.5).float()
  accuracy = predicted.eq(y).float().mean()
  # Autograd is disabled inside no_grad(), so tensors convert to NumPy directly.
  print(f'hypothesis:\n{hypothesis.cpu().numpy()}')
  print(f'predicted:\n{predicted.cpu().numpy()}')
  print(f'real y:\n{y.cpu().numpy()}')
  print(f'accuracy: {accuracy.item()}')

hypothesis:
[[2.7948347e-05]
 [9.9995732e-01]
 [9.9996006e-01]
 [5.7813952e-05]]
predicted:
[[0.]
 [1.]
 [1.]
 [0.]]
real y:
[[0.]
 [1.]
 [1.]
 [0.]]
accuracy: 1.0


In [14]:
# Reload the checkpoint written at epoch 0 by the training loop and evaluate it
# the same way as the live model.
checkpoint_path = 'mymodel_0000.pt'
# The training loop saves the whole pickled nn.Module. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which rejects such files, so full
# unpickling is requested explicitly. This is acceptable only because the file
# is produced locally by this notebook; never do this with untrusted files.
tmp_model = torch.load(checkpoint_path, map_location=device, weights_only=False)

with torch.no_grad():
  hypothesis = tmp_model(x)
  predicted = (hypothesis > 0.5).float()
  accuracy = (predicted == y).float().mean()
  # No autograd graph under no_grad(), so .detach() is unnecessary here.
  print(f'hypothesis:\n{hypothesis.cpu().numpy()}')
  print(f'predicted:\n{predicted.cpu().numpy()}')
  print(f'real y:\n{y.cpu().numpy()}')
  print(f'accuracy: {accuracy.item()}')

hypothesis:
[[9.4316805e-05]
 [9.9985385e-01]
 [9.9986172e-01]
 [1.9860093e-04]]
predicted:
[[0.]
 [1.]
 [1.]
 [0.]]
real y:
[[0.]
 [1.]
 [1.]
 [0.]]
accuracy: 1.0
